diff --git a/coderd/coderd.go b/coderd/coderd.go index bfce5a5fb1a88..11fbcf9432f7f 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -223,7 +223,11 @@ func New(options *Options) *API { ) r := chi.NewRouter() + ctx, cancel := context.WithCancel(context.Background()) api := &API{ + ctx: ctx, + cancel: cancel, + ID: uuid.New(), Options: options, RootHandler: r, @@ -669,6 +673,11 @@ func New(options *Options) *API { } type API struct { + // ctx is canceled immediately on shutdown, it can be used to abort + // interruptible tasks. + ctx context.Context + cancel context.CancelFunc + *Options // ID is a uniquely generated ID on initialization. // This is used to associate objects with a specific @@ -703,6 +712,8 @@ type API struct { // Close waits for all WebSocket connections to drain before returning. func (api *API) Close() error { + api.cancel() + api.WebsocketWaitMutex.Lock() api.WebsocketWaitGroup.Wait() api.WebsocketWaitMutex.Unlock() diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 5c14f5ea217a5..f8f40d362e999 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -601,7 +601,7 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request Valid: true, } disconnectedAt := workspaceAgent.DisconnectedAt - updateConnectionTimes := func() error { + updateConnectionTimes := func(ctx context.Context) error { err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{ ID: workspaceAgent.ID, FirstConnectedAt: firstConnectedAt, @@ -620,15 +620,23 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request } defer func() { + // If connection closed then context will be canceled, try to + // ensure our final update is sent. By waiting at most the agent + // inactive disconnect timeout we ensure that we don't block but + // also guarantee that the agent will be considered disconnected + // by normal status check. + ctx, cancel := context.WithTimeout(api.ctx, api.AgentInactiveDisconnectTimeout) + defer cancel() + disconnectedAt = sql.NullTime{ Time: database.Now(), Valid: true, } - _ = updateConnectionTimes() - _ = api.Pubsub.Publish(watchWorkspaceChannel(build.WorkspaceID), []byte{}) + _ = updateConnectionTimes(ctx) + api.publishWorkspaceUpdate(ctx, build.WorkspaceID) }() - err = updateConnectionTimes() + err = updateConnectionTimes(ctx) if err != nil { _ = conn.Close(websocket.StatusGoingAway, err.Error()) return @@ -668,7 +676,7 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request Time: database.Now(), Valid: true, } - err = updateConnectionTimes() + err = updateConnectionTimes(ctx) if err != nil { _ = conn.Close(websocket.StatusGoingAway, err.Error()) return