diff --git a/coderd/activewebsockets/sockets.go b/coderd/activewebsockets/sockets.go
new file mode 100644
index 0000000000000..daff0197bbfcd
--- /dev/null
+++ b/coderd/activewebsockets/sockets.go
@@ -0,0 +1,95 @@
+package activewebsockets
+
+import (
+	"context"
+	"net/http"
+	"runtime/pprof"
+	"sync"
+
+	"nhooyr.io/websocket"
+
+	"github.com/coder/coder/coderd/httpapi"
+	"github.com/coder/coder/codersdk"
+)
+
+// Active is a helper struct that can be used to track active
+// websocket connections. All connections will be closed when the parent
+// context is canceled.
+type Active struct {
+	ctx    context.Context
+	cancel func()
+
+	wg sync.WaitGroup
+}
+
+// New returns an Active tracker whose connections are closed when the
+// parent context is canceled.
+func New(ctx context.Context) *Active {
+	ctx, cancel := context.WithCancel(ctx)
+	return &Active{
+		ctx:    ctx,
+		cancel: cancel,
+	}
+}
+
+// Accept accepts a websocket connection and calls f with the connection.
+// The connection is tracked by the Active struct and will be
+// closed when the parent context is canceled.
+// Steps:
+// 1. Ensure we are still accepting websocket connections, and not shutting down.
+// 2. Add 1 to the wait group.
+// 3. Ensure we decrement the wait group when we are done (defer).
+// 4. Accept the websocket connection.
+//    4a. If there is an error, write the error to the response writer and return.
+// 5. Launch a goroutine to kill the websocket if the parent context is canceled.
+// 6. Call 'f' with the websocket connection.
+func (a *Active) Accept(rw http.ResponseWriter, r *http.Request, options *websocket.AcceptOptions, f func(conn *websocket.Conn)) {
+	// Ensure we are still accepting websocket connections, and not shutting down.
+	if err := a.ctx.Err(); err != nil {
+		httpapi.Write(context.Background(), rw, http.StatusBadRequest, codersdk.Response{
+			Message: "No longer accepting websocket requests.",
+			Detail:  err.Error(),
+		})
+		return
+	}
+	// Ensure we decrement the wait group when we are done.
+	a.wg.Add(1)
+	defer a.wg.Done()
+
+	// Accept the websocket connection.
+	conn, err := websocket.Accept(rw, r, options)
+	if err != nil {
+		httpapi.Write(context.Background(), rw, http.StatusBadRequest, codersdk.Response{
+			Message: "Failed to accept websocket.",
+			Detail:  err.Error(),
+		})
+		return
+	}
+	// Always track the connection before allowing the caller to handle it.
+	// This ensures the connection is closed when the parent context is canceled.
+	// The derived context ends when the parent context is canceled or when
+	// f returns.
+	ctx, cancel := context.WithCancel(a.ctx)
+	defer cancel()
+	closeConnOnContext(ctx, conn)
+
+	// Handle the websocket connection.
+	f(conn)
+}
+
+// closeConnOnContext launches a goroutine that watches a given context
+// and closes the websocket connection when that context is canceled.
+func closeConnOnContext(ctx context.Context, conn *websocket.Conn) {
+	// Label the goroutine for goroutine dumps/debugging.
+	go pprof.Do(ctx, pprof.Labels("service", "ActiveWebsockets"), func(ctx context.Context) {
+		<-ctx.Done()
+		_ = conn.Close(websocket.StatusNormalClosure, "")
+	})
+}
+
+// Close closes all active websocket connections and waits for them to
+// finish.
+func (a *Active) Close() {
+	a.cancel()
+	a.wg.Wait()
+}
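For reviewers, a minimal sketch of the intended call pattern for the new package; the `/api/echo` route, `newEchoHandler`, and the server wiring are hypothetical, not part of this change:

```go
package main

import (
	"context"
	"net/http"

	"nhooyr.io/websocket"

	"github.com/coder/coder/coderd/activewebsockets"
)

// newEchoHandler shows the intended call pattern: the handler hands its
// websocket work to Accept as a callback and never touches a wait group.
func newEchoHandler(sockets *activewebsockets.Active) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		sockets.Accept(rw, r, nil, func(conn *websocket.Conn) {
			for {
				// Reads fail once the tracker closes conn on shutdown.
				typ, data, err := conn.Read(r.Context())
				if err != nil {
					return
				}
				if err := conn.Write(r.Context(), typ, data); err != nil {
					return
				}
			}
		})
	})
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	sockets := activewebsockets.New(ctx)
	srv := &http.Server{Addr: ":8080", Handler: newEchoHandler(sockets)}

	go func() { _ = srv.ListenAndServe() }()

	// ... run until a shutdown signal (elided), then:
	sockets.Close() // reject new sockets, close live ones, wait for callbacks
	_ = srv.Shutdown(context.Background())
}
```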
diff --git a/coderd/coderd.go b/coderd/coderd.go
index 0e4b73c3e852c..4797fc043857d 100644
--- a/coderd/coderd.go
+++ b/coderd/coderd.go
@@ -12,7 +12,6 @@ import (
 	"path/filepath"
 	"regexp"
 	"strings"
-	"sync"
 	"sync/atomic"
 	"time"
 
@@ -39,6 +38,7 @@ import (
 	"github.com/coder/coder/buildinfo"
 
 	// Used to serve the Swagger endpoint
+	"github.com/coder/coder/coderd/activewebsockets"
 	_ "github.com/coder/coder/coderd/apidoc"
 	"github.com/coder/coder/coderd/audit"
 	"github.com/coder/coder/coderd/awsidentity"
@@ -316,6 +316,7 @@ func New(options *Options) *API {
 		TemplateScheduleStore: options.TemplateScheduleStore,
 		Experiments:           experiments,
 		healthCheckGroup:      &singleflight.Group[string, *healthcheck.Report]{},
+		WebsocketWatch:        activewebsockets.New(ctx),
 	}
 	if options.UpdateCheckOptions != nil {
 		api.updateChecker = updatecheck.New(
@@ -355,7 +356,7 @@ func New(options *Options) *API {
 	apiRateLimiter := httpmw.RateLimit(options.APIRateLimit, time.Minute)
 
 	derpHandler := derphttp.Handler(api.DERPServer)
-	derpHandler, api.derpCloseFunc = tailnet.WithWebsocketSupport(api.DERPServer, derpHandler)
+	derpHandler = tailnet.WithWebsocketSupport(api.WebsocketWatch.Accept, api.DERPServer, derpHandler)
 
 	r.Use(
 		httpmw.Recover(api.Logger),
@@ -784,9 +785,7 @@ type API struct {
 
 	siteHandler http.Handler
 
-	WebsocketWaitMutex sync.Mutex
-	WebsocketWaitGroup sync.WaitGroup
-	derpCloseFunc      func()
+	WebsocketWatch *activewebsockets.Active
 
 	metricsCache        *metricscache.Cache
 	workspaceAgentCache *wsconncache.Cache
@@ -803,11 +802,8 @@ type API struct {
 // Close waits for all WebSocket connections to drain before returning.
 func (api *API) Close() error {
 	api.cancel()
-	api.derpCloseFunc()
 
-	api.WebsocketWaitMutex.Lock()
-	api.WebsocketWaitGroup.Wait()
-	api.WebsocketWaitMutex.Unlock()
+	api.WebsocketWatch.Close()
 
 	api.metricsCache.Close()
 	if api.updateChecker != nil {
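The old pattern made every handler repeat the Lock/Add(1)/Unlock/Done dance, and `api.Close` had to drain the wait group by hand; both halves now live in the tracker. A test-style sketch (hypothetical, not in this diff) of the drain guarantee `api.Close` relies on:

```go
package activewebsockets_test

import (
	"context"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"nhooyr.io/websocket"

	"github.com/coder/coder/coderd/activewebsockets"
)

func TestCloseDrainsHandlers(t *testing.T) {
	sockets := activewebsockets.New(context.Background())

	handlerDone := make(chan struct{})
	srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		sockets.Accept(rw, r, nil, func(conn *websocket.Conn) {
			defer close(handlerDone)
			// Block until the tracker closes the connection.
			_, _, _ = conn.Read(r.Context())
		})
	}))
	defer srv.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	conn, _, err := websocket.Dial(ctx, srv.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close(websocket.StatusNormalClosure, "")

	// Close cancels the tracked contexts, which closes the websocket,
	// unblocks the handler's Read, and waits for the callback to return.
	sockets.Close()
	<-handlerDone
}
```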
diff --git a/coderd/healthcheck/derp_test.go b/coderd/healthcheck/derp_test.go
index fdc313e72bd28..f1c3ba4f34ce5 100644
--- a/coderd/healthcheck/derp_test.go
+++ b/coderd/healthcheck/derp_test.go
@@ -17,6 +17,7 @@ import (
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
 
+	"github.com/coder/coder/coderd/activewebsockets"
 	"github.com/coder/coder/coderd/healthcheck"
 	"github.com/coder/coder/tailnet"
 )
@@ -124,10 +125,15 @@ func TestDERP(t *testing.T) {
 	t.Run("ForceWebsockets", func(t *testing.T) {
 		t.Parallel()
 
+		ctx, cancel := context.WithCancel(context.Background())
+		defer cancel()
+
 		derpSrv := derp.NewServer(key.NewNode(), func(format string, args ...any) {
 			t.Logf(format, args...)
 		})
 		defer derpSrv.Close()
-		handler, closeHandler := tailnet.WithWebsocketSupport(derpSrv, derphttp.Handler(derpSrv))
-		defer closeHandler()
+
+		sockets := activewebsockets.New(ctx)
+		handler := tailnet.WithWebsocketSupport(sockets.Accept, derpSrv, derphttp.Handler(derpSrv))
+		defer sockets.Close()
 
 		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 			if r.Header.Get("Upgrade") == "DERP" {
@@ -140,7 +146,6 @@ func TestDERP(t *testing.T) {
 		}))
 
 		var (
-			ctx        = context.Background()
 			report     = healthcheck.DERPReport{}
 			derpURL, _ = url.Parse(srv.URL)
 			opts       = &healthcheck.DERPReportOptions{
diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go
index e03f5b9ffd28d..9d84374c56c02 100644
--- a/coderd/provisionerjobs.go
+++ b/coderd/provisionerjobs.go
@@ -113,71 +113,61 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
 		logs = []database.ProvisionerJobLog{}
 	}
 
-	api.WebsocketWaitMutex.Lock()
-	api.WebsocketWaitGroup.Add(1)
-	api.WebsocketWaitMutex.Unlock()
-	defer api.WebsocketWaitGroup.Done()
-	conn, err := websocket.Accept(rw, r, nil)
-	if err != nil {
-		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
-			Message: "Failed to accept websocket.",
-			Detail:  err.Error(),
-		})
-		return
-	}
-	go httpapi.Heartbeat(ctx, conn)
+	api.WebsocketWatch.Accept(rw, r, nil, func(conn *websocket.Conn) {
+		go httpapi.Heartbeat(ctx, conn)
 
-	ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
-	defer wsNetConn.Close() // Also closes conn.
+		ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
+		defer wsNetConn.Close() // Also closes conn.
 
-	logIdsDone := make(map[int64]bool)
+		logIdsDone := make(map[int64]bool)
 
-	// The Go stdlib JSON encoder appends a newline character after message write.
-	encoder := json.NewEncoder(wsNetConn)
-	for _, provisionerJobLog := range logs {
-		logIdsDone[provisionerJobLog.ID] = true
-		err = encoder.Encode(convertProvisionerJobLog(provisionerJobLog))
+		// The Go stdlib JSON encoder appends a newline character after each message write.
+		encoder := json.NewEncoder(wsNetConn)
+		for _, provisionerJobLog := range logs {
+			logIdsDone[provisionerJobLog.ID] = true
+			err = encoder.Encode(convertProvisionerJobLog(provisionerJobLog))
+			if err != nil {
+				return
+			}
+		}
+		job, err = api.Database.GetProvisionerJobByID(ctx, job.ID)
 		if err != nil {
+			httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
+				Message: "Internal error fetching provisioner job.",
+				Detail:  err.Error(),
+			})
 			return
 		}
-	}
-	job, err = api.Database.GetProvisionerJobByID(ctx, job.ID)
-	if err != nil {
-		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
-			Message: "Internal error fetching provisioner job.",
-			Detail:  err.Error(),
-		})
-		return
-	}
-	if job.CompletedAt.Valid {
-		// job was complete before we queried the database for historical logs
-		return
-	}
-
-	for {
-		select {
-		case <-ctx.Done():
-			logger.Debug(context.Background(), "job logs context canceled")
+		if job.CompletedAt.Valid {
+			// The job was complete before we queried the database for historical logs.
 			return
-		case log, ok := <-bufferedLogs:
-			// A nil log is sent when complete!
-			if !ok || log == nil {
-				logger.Debug(context.Background(), "reached the end of published logs")
+		}
+
+		for {
+			select {
+			case <-ctx.Done():
+				logger.Debug(context.Background(), "job logs context canceled")
 				return
-			}
-			if logIdsDone[log.ID] {
-				logger.Debug(ctx, "subscribe duplicated log",
-					slog.F("stage", log.Stage))
-			} else {
-				logger.Debug(ctx, "subscribe encoding log",
-					slog.F("stage", log.Stage))
-				err = encoder.Encode(convertProvisionerJobLog(*log))
-				if err != nil {
+			case log, ok := <-bufferedLogs:
+				// A nil log is sent when complete!
+				if !ok || log == nil {
+					logger.Debug(context.Background(), "reached the end of published logs")
 					return
 				}
+				if logIdsDone[log.ID] {
+					logger.Debug(ctx, "subscribe duplicated log",
+						slog.F("stage", log.Stage))
+				} else {
+					logger.Debug(ctx, "subscribe encoding log",
+						slog.F("stage", log.Stage))
+					err = encoder.Encode(convertProvisionerJobLog(*log))
+					if err != nil {
+						return
+					}
+				}
 			}
 		}
-	}
+	})
 }
 
 func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
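Worth noting for reviewers: this streaming works because `websocketNetConn` yields a `net.Conn` whose every `Write` is sent as one websocket message, and `json.Encoder` performs a single `Write` per `Encode`, so each log entry arrives in its own text frame. A standalone sketch of that pairing (the `Log` type and `streamLogs` helper are illustrative only):

```go
package main

import (
	"context"
	"encoding/json"
	"net/http"

	"nhooyr.io/websocket"
)

// Log is an illustrative payload type, not the one from this diff.
type Log struct {
	ID    int64  `json:"id"`
	Stage string `json:"stage"`
}

func streamLogs(rw http.ResponseWriter, r *http.Request, logs []Log) {
	conn, err := websocket.Accept(rw, r, nil)
	if err != nil {
		return
	}
	defer conn.Close(websocket.StatusNormalClosure, "")

	// Each Write on the adapter sends one websocket message, and
	// json.Encoder calls Write exactly once per Encode, so every log
	// entry is delivered as a single text frame.
	netConn := websocket.NetConn(context.Background(), conn, websocket.MessageText)
	defer netConn.Close()

	encoder := json.NewEncoder(netConn)
	for _, entry := range logs {
		if err := encoder.Encode(entry); err != nil {
			return
		}
	}
}
```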
diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go
index 6ce14dad7689e..6fca371c2e522 100644
--- a/coderd/workspaceagents.go
+++ b/coderd/workspaceagents.go
@@ -435,115 +435,105 @@ func (api *API) workspaceAgentStartupLogs(rw http.ResponseWriter, r *http.Reques
 		return
 	}
 
-	api.WebsocketWaitMutex.Lock()
-	api.WebsocketWaitGroup.Add(1)
-	api.WebsocketWaitMutex.Unlock()
-	defer api.WebsocketWaitGroup.Done()
-	conn, err := websocket.Accept(rw, r, nil)
-	if err != nil {
-		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
-			Message: "Failed to accept websocket.",
-			Detail:  err.Error(),
-		})
-		return
-	}
-	go httpapi.Heartbeat(ctx, conn)
+	api.WebsocketWatch.Accept(rw, r, nil, func(conn *websocket.Conn) {
+		go httpapi.Heartbeat(ctx, conn)
 
-	ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
-	defer wsNetConn.Close() // Also closes conn.
+		ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
+		defer wsNetConn.Close() // Also closes conn.
 
-	// The Go stdlib JSON encoder appends a newline character after message write.
-	encoder := json.NewEncoder(wsNetConn)
-	err = encoder.Encode(convertWorkspaceAgentStartupLogs(logs))
-	if err != nil {
-		return
-	}
-	if workspaceAgent.LifecycleState == database.WorkspaceAgentLifecycleStateReady {
-		// The startup script has finished running, so we can close the connection.
-		return
-	}
-
-	var (
-		bufferedLogs  = make(chan []database.WorkspaceAgentStartupLog, 128)
-		endOfLogs     atomic.Bool
-		lastSentLogID atomic.Int64
-	)
-
-	sendLogs := func(logs []database.WorkspaceAgentStartupLog) {
-		select {
-		case bufferedLogs <- logs:
-			lastSentLogID.Store(logs[len(logs)-1].ID)
-		default:
-			logger.Warn(ctx, "workspace agent startup log overflowing channel")
+		// The Go stdlib JSON encoder appends a newline character after each message write.
+		encoder := json.NewEncoder(wsNetConn)
+		err = encoder.Encode(convertWorkspaceAgentStartupLogs(logs))
+		if err != nil {
+			return
+		}
+		if workspaceAgent.LifecycleState == database.WorkspaceAgentLifecycleStateReady {
+			// The startup script has finished running, so we can close the connection.
+			return
 		}
-	}
 
-	closeSubscribe, err := api.Pubsub.Subscribe(
-		agentsdk.StartupLogsNotifyChannel(workspaceAgent.ID),
-		func(ctx context.Context, message []byte) {
-			if endOfLogs.Load() {
-				return
-			}
-			jlMsg := agentsdk.StartupLogsNotifyMessage{}
-			err := json.Unmarshal(message, &jlMsg)
-			if err != nil {
-				logger.Warn(ctx, "invalid startup logs notify message", slog.Error(err))
-				return
+		var (
+			bufferedLogs  = make(chan []database.WorkspaceAgentStartupLog, 128)
+			endOfLogs     atomic.Bool
+			lastSentLogID atomic.Int64
+		)
+
+		sendLogs := func(logs []database.WorkspaceAgentStartupLog) {
+			select {
+			case bufferedLogs <- logs:
+				lastSentLogID.Store(logs[len(logs)-1].ID)
+			default:
+				logger.Warn(ctx, "workspace agent startup log overflowing channel")
 			}
+		}
 
-			if jlMsg.CreatedAfter != 0 {
-				logs, err := api.Database.GetWorkspaceAgentStartupLogsAfter(dbauthz.As(ctx, actor), database.GetWorkspaceAgentStartupLogsAfterParams{
-					AgentID:      workspaceAgent.ID,
-					CreatedAfter: jlMsg.CreatedAfter,
-				})
-				if err != nil {
-					logger.Warn(ctx, "failed to get workspace agent startup logs after", slog.Error(err))
+		closeSubscribe, err := api.Pubsub.Subscribe(
+			agentsdk.StartupLogsNotifyChannel(workspaceAgent.ID),
+			func(ctx context.Context, message []byte) {
+				if endOfLogs.Load() {
 					return
 				}
-				sendLogs(logs)
-			}
-
-			if jlMsg.EndOfLogs {
-				endOfLogs.Store(true)
-				logs, err := api.Database.GetWorkspaceAgentStartupLogsAfter(dbauthz.As(ctx, actor), database.GetWorkspaceAgentStartupLogsAfterParams{
-					AgentID:      workspaceAgent.ID,
-					CreatedAfter: lastSentLogID.Load(),
-				})
+				jlMsg := agentsdk.StartupLogsNotifyMessage{}
+				err := json.Unmarshal(message, &jlMsg)
 				if err != nil {
-					logger.Warn(ctx, "get workspace agent startup logs after", slog.Error(err))
+					logger.Warn(ctx, "invalid startup logs notify message", slog.Error(err))
 					return
 				}
-				sendLogs(logs)
-				bufferedLogs <- nil
-			}
-		},
-	)
-	if err != nil {
-		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
-			Message: "Failed to subscribe to startup logs.",
-			Detail:  err.Error(),
-		})
-		return
-	}
-	defer closeSubscribe()
 
-	for {
-		select {
-		case <-ctx.Done():
-			logger.Debug(context.Background(), "job logs context canceled")
+				if jlMsg.CreatedAfter != 0 {
+					logs, err := api.Database.GetWorkspaceAgentStartupLogsAfter(dbauthz.As(ctx, actor), database.GetWorkspaceAgentStartupLogsAfterParams{
+						AgentID:      workspaceAgent.ID,
+						CreatedAfter: jlMsg.CreatedAfter,
+					})
+					if err != nil {
+						logger.Warn(ctx, "failed to get workspace agent startup logs after", slog.Error(err))
+						return
+					}
+					sendLogs(logs)
+				}
+
+				if jlMsg.EndOfLogs {
+					endOfLogs.Store(true)
+					logs, err := api.Database.GetWorkspaceAgentStartupLogsAfter(dbauthz.As(ctx, actor), database.GetWorkspaceAgentStartupLogsAfterParams{
+						AgentID:      workspaceAgent.ID,
+						CreatedAfter: lastSentLogID.Load(),
+					})
+					if err != nil {
+						logger.Warn(ctx, "get workspace agent startup logs after", slog.Error(err))
+						return
+					}
+					sendLogs(logs)
+					bufferedLogs <- nil
+				}
+			},
+		)
+		if err != nil {
+			httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
+				Message: "Failed to subscribe to startup logs.",
+				Detail:  err.Error(),
+			})
 			return
-		case logs, ok := <-bufferedLogs:
-			// A nil log is sent when complete!
-			if !ok || logs == nil {
-				logger.Debug(context.Background(), "reached the end of published logs")
-				return
-			}
-			err = encoder.Encode(convertWorkspaceAgentStartupLogs(logs))
-			if err != nil {
+		}
+		defer closeSubscribe()
+
+		for {
+			select {
+			case <-ctx.Done():
+				logger.Debug(context.Background(), "startup logs context canceled")
 				return
+			case logs, ok := <-bufferedLogs:
+				// A nil log is sent when complete!
+				if !ok || logs == nil {
+					logger.Debug(context.Background(), "reached the end of published logs")
+					return
+				}
+				err = encoder.Encode(convertWorkspaceAgentStartupLogs(logs))
+				if err != nil {
+					return
+				}
 			}
 		}
-	}
+	})
 }
 
 // workspaceAgentPTY spawns a PTY and pipes it over a WebSocket.
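The startup-log fan-in keeps its shape inside the closure: the pubsub callback must never block, so batches go through a bounded channel, overflow is dropped (the real handler later recovers dropped rows by re-querying with `lastSentLogID`), and a nil batch is the end-of-logs sentinel. A self-contained sketch of that channel discipline (`Batch` stands in for the database row slice):

```go
package main

import (
	"context"
	"fmt"
	"sync/atomic"
)

// Batch stands in for []database.WorkspaceAgentStartupLog.
type Batch []string

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Bounded buffer: the publisher must never block, so it drops
	// batches when the consumer falls behind. A nil batch is the
	// end-of-logs sentinel, sent blocking so it cannot be dropped.
	bufferedLogs := make(chan Batch, 128)
	var endOfLogs atomic.Bool

	sendLogs := func(b Batch) {
		select {
		case bufferedLogs <- b:
		default:
			fmt.Println("overflow: dropping batch")
		}
	}

	// Producer (stands in for the pubsub callback).
	go func() {
		for i := 0; i < 3; i++ {
			sendLogs(Batch{fmt.Sprintf("line %d", i)})
		}
		endOfLogs.Store(true) // ignore any further notifications
		bufferedLogs <- nil
	}()

	// Consumer (stands in for the websocket write loop).
	for {
		select {
		case <-ctx.Done():
			return
		case logs, ok := <-bufferedLogs:
			if !ok || logs == nil {
				fmt.Println("reached the end of published logs")
				return
			}
			fmt.Println("send", logs)
		}
	}
}
```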
@@ -559,11 +549,6 @@ func (api *API) workspaceAgentStartupLogs(rw http.ResponseWriter, r *http.Reques
 func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
 	ctx := r.Context()
 
-	api.WebsocketWaitMutex.Lock()
-	api.WebsocketWaitGroup.Add(1)
-	api.WebsocketWaitMutex.Unlock()
-	defer api.WebsocketWaitGroup.Done()
-
 	appToken, ok := workspaceapps.ResolveRequest(api.Logger, api.AccessURL, api.WorkspaceAppsProvider, rw, r, workspaceapps.Request{
 		AccessMethod: workspaceapps.AccessMethodTerminal,
 		BasePath:     r.URL.Path,
@@ -592,35 +577,30 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
 		width = 80
 	}
 
-	conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
-		CompressionMode: websocket.CompressionDisabled,
-	})
-	if err != nil {
-		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
-			Message: "Failed to accept websocket.",
-			Detail:  err.Error(),
-		})
-		return
-	}
-
-	ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
-	defer wsNetConn.Close() // Also closes conn.
+	// Keep compression disabled, matching the Accept options this
+	// closure replaces.
+	api.WebsocketWatch.Accept(rw, r, &websocket.AcceptOptions{
+		CompressionMode: websocket.CompressionDisabled,
+	}, func(conn *websocket.Conn) {
+		ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
+		defer wsNetConn.Close() // Also closes conn.
-	go httpapi.Heartbeat(ctx, conn)
+		go httpapi.Heartbeat(ctx, conn)
 
-	agentConn, release, err := api.workspaceAgentCache.Acquire(appToken.AgentID)
-	if err != nil {
-		_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
-		return
-	}
-	defer release()
-	ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"))
-	if err != nil {
-		_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
-		return
-	}
-	defer ptNetConn.Close()
-	agent.Bicopy(ctx, wsNetConn, ptNetConn)
+		agentConn, release, err := api.workspaceAgentCache.Acquire(appToken.AgentID)
+		if err != nil {
+			_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
+			return
+		}
+		defer release()
+		ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"))
+		if err != nil {
+			_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
+			return
+		}
+		defer ptNetConn.Close()
+		agent.Bicopy(ctx, wsNetConn, ptNetConn)
+	})
 }
 
 // @Summary Get listening ports for workspace agent
@@ -816,10 +792,6 @@ func (api *API) workspaceAgentConnection(rw http.ResponseWriter, r *http.Request
 func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request) {
 	ctx := r.Context()
 
-	api.WebsocketWaitMutex.Lock()
-	api.WebsocketWaitGroup.Add(1)
-	api.WebsocketWaitMutex.Unlock()
-	defer api.WebsocketWaitGroup.Done()
 	workspaceAgent := httpmw.WorkspaceAgent(r)
 	resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)
 	if err != nil {
@@ -883,189 +855,182 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
 		return
 	}
 
-	conn, err := websocket.Accept(rw, r, nil)
-	if err != nil {
-		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
-			Message: "Failed to accept websocket.",
-			Detail:  err.Error(),
-		})
-		return
-	}
+	api.WebsocketWatch.Accept(rw, r, nil, func(conn *websocket.Conn) {
+		ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
+		defer wsNetConn.Close()
 
-	ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
-	defer wsNetConn.Close()
+		// We use a custom heartbeat routine here instead of `httpapi.Heartbeat`
+		// because we want to log the agent's last ping time.
+		lastPing := time.Now() // Since the agent initiated the request, assume it's alive.
+		var pingMu sync.Mutex
+		go pprof.Do(ctx, pprof.Labels("agent", workspaceAgent.ID.String()), func(ctx context.Context) {
+			// TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout?
+			t := time.NewTicker(api.AgentConnectionUpdateFrequency)
+			defer t.Stop()
 
-	// We use a custom heartbeat routine here instead of `httpapi.Heartbeat`
-	// because we want to log the agent's last ping time.
-	lastPing := time.Now() // Since the agent initiated the request, assume it's alive.
-	var pingMu sync.Mutex
-	go pprof.Do(ctx, pprof.Labels("agent", workspaceAgent.ID.String()), func(ctx context.Context) {
-		// TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout?
-		t := time.NewTicker(api.AgentConnectionUpdateFrequency)
-		defer t.Stop()
+			for {
+				select {
+				case <-t.C:
+				case <-ctx.Done():
+					return
+				}
 
-		for {
-			select {
-			case <-t.C:
-			case <-ctx.Done():
-				return
+				// We don't need a context that times out here because the ping will
+				// eventually go through. If the context times out, then other
+				// websocket read operations will receive an error, obfuscating the
+				// actual problem.
+				err := conn.Ping(ctx)
+				if err != nil {
+					return
+				}
+				pingMu.Lock()
+				lastPing = time.Now()
+				pingMu.Unlock()
 			}
+		})
 
-			// We don't need a context that times out here because the ping will
-			// eventually go through. If the context times out, then other
-			// websocket read operations will receive an error, obfuscating the
-			// actual problem.
-			err := conn.Ping(ctx)
-			if err != nil {
-				return
+		firstConnectedAt := workspaceAgent.FirstConnectedAt
+		if !firstConnectedAt.Valid {
+			firstConnectedAt = sql.NullTime{
+				Time:  database.Now(),
+				Valid: true,
 			}
-			pingMu.Lock()
-			lastPing = time.Now()
-			pingMu.Unlock()
 		}
-	})
-
-	firstConnectedAt := workspaceAgent.FirstConnectedAt
-	if !firstConnectedAt.Valid {
-		firstConnectedAt = sql.NullTime{
+		lastConnectedAt := sql.NullTime{
 			Time:  database.Now(),
 			Valid: true,
 		}
-	}
-	lastConnectedAt := sql.NullTime{
-		Time:  database.Now(),
-		Valid: true,
-	}
-	disconnectedAt := workspaceAgent.DisconnectedAt
-	updateConnectionTimes := func(ctx context.Context) error {
-		err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{
-			ID:               workspaceAgent.ID,
-			FirstConnectedAt: firstConnectedAt,
-			LastConnectedAt:  lastConnectedAt,
-			DisconnectedAt:   disconnectedAt,
-			UpdatedAt:        database.Now(),
-			LastConnectedReplicaID: uuid.NullUUID{
-				UUID:  api.ID,
-				Valid: true,
-			},
-		})
-		if err != nil {
-			return err
-		}
-		return nil
-	}
-
-	defer func() {
-		// If connection closed then context will be canceled, try to
-		// ensure our final update is sent. By waiting at most the agent
-		// inactive disconnect timeout we ensure that we don't block but
-		// also guarantee that the agent will be considered disconnected
-		// by normal status check.
-		//
-		// Use a system context as the agent has disconnected and that token
-		// may no longer be valid.
-		//nolint:gocritic
-		ctx, cancel := context.WithTimeout(dbauthz.AsSystemRestricted(api.ctx), api.AgentInactiveDisconnectTimeout)
-		defer cancel()
-
-		// Only update timestamp if the disconnect is new.
-		if !disconnectedAt.Valid {
-			disconnectedAt = sql.NullTime{
-				Time:  database.Now(),
-				Valid: true,
+		disconnectedAt := workspaceAgent.DisconnectedAt
+		updateConnectionTimes := func(ctx context.Context) error {
+			err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{
+				ID:               workspaceAgent.ID,
+				FirstConnectedAt: firstConnectedAt,
+				LastConnectedAt:  lastConnectedAt,
+				DisconnectedAt:   disconnectedAt,
+				UpdatedAt:        database.Now(),
+				LastConnectedReplicaID: uuid.NullUUID{
+					UUID:  api.ID,
+					Valid: true,
+				},
+			})
+			if err != nil {
+				return err
 			}
+			return nil
 		}
-		err := updateConnectionTimes(ctx)
-		if err != nil {
-			// This is a bug with unit tests that cancel the app context and
-			// cause this error log to be generated. We should fix the unit tests
-			// as this is a valid log.
-			//
-			// The pq error occurs when the server is shutting down.
+
+		defer func() {
+			// If the connection closed, the context will be canceled; try to
+			// ensure our final update is sent. By waiting at most the agent
+			// inactive disconnect timeout we ensure that we don't block, but
+			// also guarantee that the agent will be considered disconnected
+			// by the normal status check.
+			//
+			// Use a system context as the agent has disconnected and that token
+			// may no longer be valid.
+			//nolint:gocritic
+			ctx, cancel := context.WithTimeout(dbauthz.AsSystemRestricted(api.ctx), api.AgentInactiveDisconnectTimeout)
+			defer cancel()
+
+			// Only update the timestamp if the disconnect is new.
+			if !disconnectedAt.Valid {
+				disconnectedAt = sql.NullTime{
+					Time:  database.Now(),
+					Valid: true,
+				}
+			}
+			err := updateConnectionTimes(ctx)
+			if err != nil {
+				// This is a bug with unit tests that cancel the app context and
+				// cause this error log to be generated. We should fix the unit tests,
+				// as this is a valid log.
+				//
+				// The pq error occurs when the server is shutting down.
-			if !xerrors.Is(err, context.Canceled) && !database.IsQueryCanceledError(err) {
-				api.Logger.Error(ctx, "failed to update agent disconnect time",
-					slog.Error(err),
-					slog.F("workspace", build.WorkspaceID),
-				)
+				if !xerrors.Is(err, context.Canceled) && !database.IsQueryCanceledError(err) {
+					api.Logger.Error(ctx, "failed to update agent disconnect time",
+						slog.Error(err),
+						slog.F("workspace", build.WorkspaceID),
+					)
+				}
 			}
+			api.publishWorkspaceUpdate(ctx, build.WorkspaceID)
+		}()
+
+		err = updateConnectionTimes(ctx)
+		if err != nil {
+			_ = conn.Close(websocket.StatusGoingAway, err.Error())
+			return
 		}
 		api.publishWorkspaceUpdate(ctx, build.WorkspaceID)
-	}()
-
-	err = updateConnectionTimes(ctx)
-	if err != nil {
-		_ = conn.Close(websocket.StatusGoingAway, err.Error())
-		return
-	}
-	api.publishWorkspaceUpdate(ctx, build.WorkspaceID)
 
-	// End span so we don't get long lived trace data.
-	tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx))
-	// Ignore all trace spans after this.
+		// End the span so we don't get long-lived trace data.
+		tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx))
+		// Ignore all trace spans after this.
+		ctx = trace.ContextWithSpan(ctx, tracing.NoopSpan)
 
-	api.Logger.Info(ctx, "accepting agent", slog.F("agent", workspaceAgent))
+		api.Logger.Info(ctx, "accepting agent", slog.F("agent", workspaceAgent))
 
-	defer conn.Close(websocket.StatusNormalClosure, "")
+		defer conn.Close(websocket.StatusNormalClosure, "")
 
-	closeChan := make(chan struct{})
-	go func() {
-		defer close(closeChan)
-		err := (*api.TailnetCoordinator.Load()).ServeAgent(wsNetConn, workspaceAgent.ID,
-			fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name),
-		)
-		if err != nil {
-			api.Logger.Warn(ctx, "tailnet coordinator agent error", slog.Error(err))
-			_ = conn.Close(websocket.StatusInternalError, err.Error())
-			return
-		}
-	}()
-	ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
-	defer ticker.Stop()
-	for {
-		select {
-		case <-closeChan:
-			return
-		case <-ticker.C:
-		}
+		closeChan := make(chan struct{})
+		go func() {
+			defer close(closeChan)
+			err := (*api.TailnetCoordinator.Load()).ServeAgent(wsNetConn, workspaceAgent.ID,
+				fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name),
+			)
+			if err != nil {
+				api.Logger.Warn(ctx, "tailnet coordinator agent error", slog.Error(err))
+				_ = conn.Close(websocket.StatusInternalError, err.Error())
+				return
+			}
+		}()
+		ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-closeChan:
+				return
+			case <-ticker.C:
+			}
 
-		pingMu.Lock()
-		lastPing := lastPing
-		pingMu.Unlock()
+			pingMu.Lock()
+			lastPing := lastPing
+			pingMu.Unlock()
 
-		var connectionStatusChanged bool
-		if time.Since(lastPing) > api.AgentInactiveDisconnectTimeout {
-			if !disconnectedAt.Valid {
-				connectionStatusChanged = true
-				disconnectedAt = sql.NullTime{
+			var connectionStatusChanged bool
+			if time.Since(lastPing) > api.AgentInactiveDisconnectTimeout {
+				if !disconnectedAt.Valid {
+					connectionStatusChanged = true
+					disconnectedAt = sql.NullTime{
+						Time:  database.Now(),
+						Valid: true,
+					}
+				}
+			} else {
+				connectionStatusChanged = disconnectedAt.Valid
+				// TODO(mafredri): Should we update it here or allow lastConnectedAt to shadow it?
+				disconnectedAt = sql.NullTime{}
+				lastConnectedAt = sql.NullTime{
 					Time:  database.Now(),
 					Valid: true,
 				}
 			}
-		} else {
-			connectionStatusChanged = disconnectedAt.Valid
-			// TODO(mafredri): Should we update it here or allow lastConnectedAt to shadow it?
-			disconnectedAt = sql.NullTime{}
-			lastConnectedAt = sql.NullTime{
-				Time:  database.Now(),
-				Valid: true,
+			err = updateConnectionTimes(ctx)
+			if err != nil {
+				_ = conn.Close(websocket.StatusGoingAway, err.Error())
+				return
+			}
+			if connectionStatusChanged {
+				api.publishWorkspaceUpdate(ctx, build.WorkspaceID)
+			}
+			err := ensureLatestBuild()
+			if err != nil {
+				// Disconnect agents that are no longer valid.
+				_ = conn.Close(websocket.StatusGoingAway, "")
+				return
 			}
 		}
-		err = updateConnectionTimes(ctx)
-		if err != nil {
-			_ = conn.Close(websocket.StatusGoingAway, err.Error())
-			return
-		}
-		if connectionStatusChanged {
-			api.publishWorkspaceUpdate(ctx, build.WorkspaceID)
-		}
-		err := ensureLatestBuild()
-		if err != nil {
-			// Disconnect agents that are no longer valid.
-			_ = conn.Close(websocket.StatusGoingAway, "")
-			return
-		}
-	}
+	})
 }
 
 // workspaceAgentClientCoordinate accepts a WebSocket that reads node network updates.
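The liveness logic that moved into the closure is easiest to review in isolation: one goroutine pings and records the last success under a mutex, while the main loop ages that timestamp against the inactivity timeout to flip connected/disconnected. A condensed sketch of that pattern (`watchLiveness` and `onChange` are illustrative, not part of the diff):

```go
package main

import (
	"context"
	"sync"
	"time"

	"nhooyr.io/websocket"
)

// watchLiveness mirrors the handler's heartbeat: one goroutine pings and
// timestamps successes; the caller's loop decides connected/disconnected by
// comparing the age of that timestamp against an inactivity timeout.
func watchLiveness(ctx context.Context, conn *websocket.Conn, freq, timeout time.Duration, onChange func(connected bool)) {
	var (
		mu       sync.Mutex
		lastPing = time.Now() // the peer just connected, assume alive
	)

	go func() {
		t := time.NewTicker(freq)
		defer t.Stop()
		for {
			select {
			case <-t.C:
			case <-ctx.Done():
				return
			}
			// No per-ping timeout: if ctx dies, concurrent reads on the
			// connection surface the error instead.
			if err := conn.Ping(ctx); err != nil {
				return
			}
			mu.Lock()
			lastPing = time.Now()
			mu.Unlock()
		}
	}()

	t := time.NewTicker(freq)
	defer t.Stop()
	connected := true
	for {
		select {
		case <-ctx.Done():
			return
		case <-t.C:
		}
		mu.Lock()
		age := time.Since(lastPing)
		mu.Unlock()
		if now := age <= timeout; now != connected {
			connected = now
			onChange(connected)
		}
	}
}
```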
@@ -1096,31 +1061,21 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
 		}
 	}
 
-	api.WebsocketWaitMutex.Lock()
-	api.WebsocketWaitGroup.Add(1)
-	api.WebsocketWaitMutex.Unlock()
-	defer api.WebsocketWaitGroup.Done()
 	workspaceAgent := httpmw.WorkspaceAgentParam(r)
 
-	conn, err := websocket.Accept(rw, r, nil)
-	if err != nil {
-		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
-			Message: "Failed to accept websocket.",
-			Detail:  err.Error(),
-		})
-		return
-	}
-	ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
-	defer wsNetConn.Close()
+	api.WebsocketWatch.Accept(rw, r, nil, func(conn *websocket.Conn) {
+		ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
+		defer wsNetConn.Close()
 
-	go httpapi.Heartbeat(ctx, conn)
+		go httpapi.Heartbeat(ctx, conn)
 
-	defer conn.Close(websocket.StatusNormalClosure, "")
-	err = (*api.TailnetCoordinator.Load()).ServeClient(wsNetConn, uuid.New(), workspaceAgent.ID)
-	if err != nil {
-		_ = conn.Close(websocket.StatusInternalError, err.Error())
-		return
-	}
+		defer conn.Close(websocket.StatusNormalClosure, "")
+		err := (*api.TailnetCoordinator.Load()).ServeClient(wsNetConn, uuid.New(), workspaceAgent.ID)
+		if err != nil {
+			_ = conn.Close(websocket.StatusInternalError, err.Error())
+			return
+		}
+	})
 }
 
 func convertApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp {
diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go
index 27573014edf88..218b613bf3802 100644
--- a/enterprise/coderd/provisionerdaemons.go
+++ b/enterprise/coderd/provisionerdaemons.go
@@ -185,71 +185,60 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
 		return
 	}
 
-	api.AGPL.WebsocketWaitMutex.Lock()
-	api.AGPL.WebsocketWaitGroup.Add(1)
-	api.AGPL.WebsocketWaitMutex.Unlock()
-	defer api.AGPL.WebsocketWaitGroup.Done()
-
-	conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
+	api.AGPL.WebsocketWatch.Accept(rw, r, &websocket.AcceptOptions{
 		// Need to disable compression to avoid a data-race.
 		CompressionMode: websocket.CompressionDisabled,
-	})
-	if err != nil {
-		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
-			Message: "Internal error accepting websocket connection.",
-			Detail:  err.Error(),
+	}, func(conn *websocket.Conn) {
+		// Align with the frame size of yamux.
+		conn.SetReadLimit(256 * 1024)
+
+		// Multiplex the incoming connection using yamux.
+		// This allows multiple function calls to occur over
+		// the same connection.
+		config := yamux.DefaultConfig()
+		config.LogOutput = io.Discard
+		ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
+		defer wsNetConn.Close()
+		session, err := yamux.Server(wsNetConn, config)
+		if err != nil {
+			_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("multiplex server: %s", err))
+			return
+		}
+		mux := drpcmux.New()
+		err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
+			AccessURL:             api.AccessURL,
+			GitAuthConfigs:        api.GitAuthConfigs,
+			OIDCConfig:            api.OIDCConfig,
+			ID:                    daemon.ID,
+			Database:              api.Database,
+			Pubsub:                api.Pubsub,
+			Provisioners:          daemon.Provisioners,
+			Telemetry:             api.Telemetry,
+			Auditor:               &api.AGPL.Auditor,
+			TemplateScheduleStore: api.AGPL.TemplateScheduleStore,
+			Logger:                api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
+			Tags:                  rawTags,
 		})
-		return
-	}
-	// Align with the frame size of yamux.
-	conn.SetReadLimit(256 * 1024)
-
-	// Multiplexes the incoming connection using yamux.
-	// This allows multiple function calls to occur over
-	// the same connection.
-	config := yamux.DefaultConfig()
-	config.LogOutput = io.Discard
-	ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
-	defer wsNetConn.Close()
-	session, err := yamux.Server(wsNetConn, config)
-	if err != nil {
-		_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("multiplex server: %s", err))
-		return
-	}
-	mux := drpcmux.New()
-	err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
-		AccessURL:             api.AccessURL,
-		GitAuthConfigs:        api.GitAuthConfigs,
-		OIDCConfig:            api.OIDCConfig,
-		ID:                    daemon.ID,
-		Database:              api.Database,
-		Pubsub:                api.Pubsub,
-		Provisioners:          daemon.Provisioners,
-		Telemetry:             api.Telemetry,
-		Auditor:               &api.AGPL.Auditor,
-		TemplateScheduleStore: api.AGPL.TemplateScheduleStore,
-		Logger:                api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
-		Tags:                  rawTags,
-	})
-	if err != nil {
-		_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("drpc register provisioner daemon: %s", err))
-		return
-	}
-	server := drpcserver.NewWithOptions(mux, drpcserver.Options{
-		Log: func(err error) {
-			if xerrors.Is(err, io.EOF) {
-				return
-			}
-			api.Logger.Debug(ctx, "drpc server error", slog.Error(err))
-		},
+		if err != nil {
+			_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("drpc register provisioner daemon: %s", err))
+			return
+		}
+		server := drpcserver.NewWithOptions(mux, drpcserver.Options{
+			Log: func(err error) {
+				if xerrors.Is(err, io.EOF) {
+					return
+				}
+				api.Logger.Debug(ctx, "drpc server error", slog.Error(err))
+			},
+		})
+		err = server.Serve(ctx, session)
+		if err != nil && !xerrors.Is(err, io.EOF) {
+			api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err))
+			_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
+			return
+		}
+		_ = conn.Close(websocket.StatusGoingAway, "")
 	})
-	err = server.Serve(ctx, session)
-	if err != nil && !xerrors.Is(err, io.EOF) {
-		api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err))
-		_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
-		return
-	}
-	_ = conn.Close(websocket.StatusGoingAway, "")
 }
 
 func convertProvisionerDaemon(daemon database.ProvisionerDaemon) codersdk.ProvisionerDaemon {
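The multiplexing setup that moved into the closure is the standard yamux-over-websocket pattern: adapt the socket to a `net.Conn`, size the read limit to yamux's frame size, and serve logical streams off a single connection. A stripped-down sketch with an echo handler in place of the dRPC server (`serveMuxed` is illustrative):

```go
package main

import (
	"context"
	"io"
	"net/http"

	"github.com/hashicorp/yamux"
	"nhooyr.io/websocket"
)

// serveMuxed shows the pattern in isolation: adapt the websocket to a
// net.Conn, hand it to yamux as the server side, then accept logical
// streams that each behave like their own connection.
func serveMuxed(rw http.ResponseWriter, r *http.Request) {
	conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
		// Compression stays disabled for the same reason as the diff:
		// it has caused data races with concurrent writers.
		CompressionMode: websocket.CompressionDisabled,
	})
	if err != nil {
		return
	}
	// Match yamux's frame size so a single frame never exceeds the
	// websocket read limit.
	conn.SetReadLimit(256 * 1024)

	netConn := websocket.NetConn(context.Background(), conn, websocket.MessageBinary)
	defer netConn.Close()

	cfg := yamux.DefaultConfig()
	cfg.LogOutput = io.Discard
	session, err := yamux.Server(netConn, cfg)
	if err != nil {
		return
	}
	defer session.Close()

	for {
		stream, err := session.Accept()
		if err != nil {
			return // session closed
		}
		go func() {
			defer stream.Close()
			_, _ = io.Copy(stream, stream) // echo each logical stream
		}()
	}
}
```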
diff --git a/tailnet/derp.go b/tailnet/derp.go
index 6c8e363e91e29..7d84c98f73bed 100644
--- a/tailnet/derp.go
+++ b/tailnet/derp.go
@@ -2,71 +2,52 @@ package tailnet
 
 import (
 	"bufio"
-	"context"
-	"log"
 	"net/http"
 	"strings"
-	"sync"
 
 	"nhooyr.io/websocket"
 	"tailscale.com/derp"
 	"tailscale.com/net/wsconn"
 )
 
+// HandleWebsocket is the signature of a function that accepts a websocket
+// connection and invokes f with it once the upgrade succeeds.
+type HandleWebsocket func(rw http.ResponseWriter, r *http.Request, options *websocket.AcceptOptions, f func(conn *websocket.Conn))
+
 // WithWebsocketSupport returns an http.Handler that upgrades
 // connections to the "derp" subprotocol to WebSockets and
 // passes them to the DERP server.
 // Taken from: https://github.com/tailscale/tailscale/blob/e3211ff88ba85435f70984cf67d9b353f3d650d8/cmd/derper/websocket.go#L21
-func WithWebsocketSupport(s *derp.Server, base http.Handler) (http.Handler, func()) {
-	var mu sync.Mutex
-	var waitGroup sync.WaitGroup
-	ctx, cancelFunc := context.WithCancel(context.Background())
-
+// The accept function accepts the websocket connection and lets the caller
+// tie the connection's lifecycle to its own (e.g. close the connection on
+// shutdown).
+func WithWebsocketSupport(accept HandleWebsocket, s *derp.Server, base http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		up := strings.ToLower(r.Header.Get("Upgrade"))
+		up := strings.ToLower(r.Header.Get("Upgrade"))
 
-		// Very early versions of Tailscale set "Upgrade: WebSocket" but didn't actually
-		// speak WebSockets (they still assumed DERP's binary framing). So to distinguish
-		// clients that actually want WebSockets, look for an explicit "derp" subprotocol.
-		if up != "websocket" || !strings.Contains(r.Header.Get("Sec-Websocket-Protocol"), "derp") {
-			base.ServeHTTP(w, r)
-			return
-		}
+		// Very early versions of Tailscale set "Upgrade: WebSocket" but didn't actually
+		// speak WebSockets (they still assumed DERP's binary framing). So to distinguish
+		// clients that actually want WebSockets, look for an explicit "derp" subprotocol.
+		if up != "websocket" || !strings.Contains(r.Header.Get("Sec-Websocket-Protocol"), "derp") {
+			base.ServeHTTP(w, r)
+			return
+		}
 
-		mu.Lock()
-		if ctx.Err() != nil {
-			mu.Unlock()
+		accept(w, r, &websocket.AcceptOptions{
+			Subprotocols:   []string{"derp"},
+			OriginPatterns: []string{"*"},
+			// Disable compression because we transmit WireGuard messages that
+			// are not compressible.
+			// Additionally, Safari has a broken implementation of compression
+			// (see https://github.com/nhooyr/websocket/issues/218) that makes
+			// enabling it actively harmful.
+			CompressionMode: websocket.CompressionDisabled,
+		}, func(conn *websocket.Conn) {
+			defer conn.Close(websocket.StatusInternalError, "closing")
+			if conn.Subprotocol() != "derp" {
+				conn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol")
 				return
 			}
-		waitGroup.Add(1)
-		mu.Unlock()
-		defer waitGroup.Done()
-		c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
-			Subprotocols:   []string{"derp"},
-			OriginPatterns: []string{"*"},
-			// Disable compression because we transmit WireGuard messages that
-			// are not compressible.
-			// Additionally, Safari has a broken implementation of compression
-			// (see https://github.com/nhooyr/websocket/issues/218) that makes
-			// enabling it actively harmful.
-			CompressionMode: websocket.CompressionDisabled,
-		})
-		if err != nil {
-			log.Printf("websocket.Accept: %v", err)
-			return
-		}
-		defer c.Close(websocket.StatusInternalError, "closing")
-		if c.Subprotocol() != "derp" {
-			c.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol")
-			return
-		}
-		wc := wsconn.NetConn(ctx, c, websocket.MessageBinary)
+			wc := wsconn.NetConn(r.Context(), conn, websocket.MessageBinary)
 			brw := bufio.NewReadWriter(bufio.NewReader(wc), bufio.NewWriter(wc))
-		s.Accept(ctx, wc, brw, r.RemoteAddr)
-	}), func() {
-		cancelFunc()
-		mu.Lock()
-		waitGroup.Wait()
-		mu.Unlock()
-	}
+			s.Accept(r.Context(), wc, brw, r.RemoteAddr)
+		})
+	})
 }
diff --git a/tailnet/tailnettest/tailnettest.go b/tailnet/tailnettest/tailnettest.go
index 482c1232e258a..eb5155f0a42de 100644
--- a/tailnet/tailnettest/tailnettest.go
+++ b/tailnet/tailnettest/tailnettest.go
@@ -1,6 +1,7 @@
 package tailnettest
 
 import (
+	"context"
 	"crypto/tls"
 	"fmt"
 	"html"
@@ -18,6 +19,7 @@ import (
 	"tailscale.com/types/nettype"
 
 	"cdr.dev/slog/sloggers/slogtest"
+	"github.com/coder/coder/coderd/activewebsockets"
 	"github.com/coder/coder/tailnet"
 )
 
@@ -71,8 +73,12 @@ func RunDERPOnlyWebSockets(t *testing.T) *tailcfg.DERPMap {
 	logf := tailnet.Logger(slogtest.Make(t, nil))
 	d := derp.NewServer(key.NewNode(), logf)
 	handler := derphttp.Handler(d)
+
+	// Cancel via t.Cleanup, not defer: a defer would cancel the context as
+	// soon as this helper returns, closing the tracker while the test's
+	// server is still in use.
+	ctx, cancel := context.WithCancel(context.Background())
+	t.Cleanup(cancel)
+	sockets := activewebsockets.New(ctx)
 
-	var closeFunc func()
-	handler, closeFunc = tailnet.WithWebsocketSupport(d, handler)
+	handler = tailnet.WithWebsocketSupport(sockets.Accept, d, handler)
 	server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		if r.URL.Path != "/derp" {
 			handler.ServeHTTP(w, r)
@@ -91,7 +97,7 @@ func RunDERPOnlyWebSockets(t *testing.T) *tailcfg.DERPMap {
 	t.Cleanup(func() {
 		server.CloseClientConnections()
 		server.Close()
-		closeFunc()
+		sockets.Close()
 		d.Close()
 	})