diff --git a/coderd/coderd.go b/coderd/coderd.go index 780386c4a5db1..f301265cc5ad7 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1108,6 +1108,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context) (client pro logger := api.Logger.Named(fmt.Sprintf("inmem-provisionerd-%s", name)) logger.Info(ctx, "starting in-memory provisioner daemon") srv, err := provisionerdserver.NewServer( + api.ctx, api.AccessURL, uuid.New(), logger, diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index dd8bed7fef1b5..5afb85565c50b 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -58,6 +58,10 @@ type Options struct { } type server struct { + // lifecycleCtx must be tied to the API server's lifecycle + // as when the API server shuts down, we want to cancel any + // long-running operations. + lifecycleCtx context.Context AccessURL *url.URL ID uuid.UUID Logger slog.Logger @@ -107,6 +111,7 @@ func (t Tags) Valid() error { } func NewServer( + lifecycleCtx context.Context, accessURL *url.URL, id uuid.UUID, logger slog.Logger, @@ -124,7 +129,10 @@ func NewServer( deploymentValues *codersdk.DeploymentValues, options Options, ) (proto.DRPCProvisionerDaemonServer, error) { - // Panic early if pointers are nil + // Fail-fast if pointers are nil + if lifecycleCtx == nil { + return nil, xerrors.New("ctx is nil") + } if quotaCommitter == nil { return nil, xerrors.New("quotaCommitter is nil") } @@ -153,6 +161,7 @@ func NewServer( options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur } return &server{ + lifecycleCtx: lifecycleCtx, AccessURL: accessURL, ID: id, Logger: logger, @@ -1184,16 +1193,21 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) } go func() { for _, wait := range updates { - // Wait for the next potential timeout to occur. Note that we - // can't listen on the context here because we will hang around - // after this function has returned. The s also doesn't - // have a shutdown signal we can listen to. - <-wait - if err := s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{}); err != nil { - s.Logger.Error(ctx, "workspace notification after agent timeout failed", + select { + case <-s.lifecycleCtx.Done(): + // If the server is shutting down, we don't want to wait around. + s.Logger.Debug(ctx, "stopping notifications due to server shutdown", slog.F("workspace_build_id", workspaceBuild.ID), - slog.Error(err), ) + return + case <-wait: + // Wait for the next potential timeout to occur. + if err := s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{}); err != nil { + s.Logger.Error(ctx, "workspace notification after agent timeout failed", + slog.F("workspace_build_id", workspaceBuild.ID), + slog.Error(err), + ) + } } } }() diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index 34f3b8377c5d1..db97724c72987 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -1733,6 +1733,7 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi } srv, err := provisionerdserver.NewServer( + ctx, &url.URL{}, srvID, slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}), diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index c74a439e2db87..70f59f40308f0 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -243,6 +243,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) logger := api.Logger.Named(fmt.Sprintf("ext-provisionerd-%s", name)) logger.Info(ctx, "starting external provisioner daemon") srv, err := provisionerdserver.NewServer( + api.ctx, api.AccessURL, uuid.New(), logger,