Skip to content

Commit 6875faf

Browse files
authored
fix(coderd/provisionerdserver): pass through api ctx to provisionerdserver (#10259)
Passes through coderd API ctx to provisionerd server so we can cancel workspace updates when API is shutting down.
1 parent 01792f0 commit 6875faf

File tree

4 files changed

+26
-9
lines changed

4 files changed

+26
-9
lines changed

coderd/coderd.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,6 +1108,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context) (client pro
11081108
logger := api.Logger.Named(fmt.Sprintf("inmem-provisionerd-%s", name))
11091109
logger.Info(ctx, "starting in-memory provisioner daemon")
11101110
srv, err := provisionerdserver.NewServer(
1111+
api.ctx,
11111112
api.AccessURL,
11121113
uuid.New(),
11131114
logger,

coderd/provisionerdserver/provisionerdserver.go

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ type Options struct {
5858
}
5959

6060
type server struct {
61+
// lifecycleCtx must be tied to the API server's lifecycle
62+
// as when the API server shuts down, we want to cancel any
63+
// long-running operations.
64+
lifecycleCtx context.Context
6165
AccessURL *url.URL
6266
ID uuid.UUID
6367
Logger slog.Logger
@@ -107,6 +111,7 @@ func (t Tags) Valid() error {
107111
}
108112

109113
func NewServer(
114+
lifecycleCtx context.Context,
110115
accessURL *url.URL,
111116
id uuid.UUID,
112117
logger slog.Logger,
@@ -124,7 +129,10 @@ func NewServer(
124129
deploymentValues *codersdk.DeploymentValues,
125130
options Options,
126131
) (proto.DRPCProvisionerDaemonServer, error) {
127-
// Panic early if pointers are nil
132+
// Fail-fast if pointers are nil
133+
if lifecycleCtx == nil {
134+
return nil, xerrors.New("ctx is nil")
135+
}
128136
if quotaCommitter == nil {
129137
return nil, xerrors.New("quotaCommitter is nil")
130138
}
@@ -153,6 +161,7 @@ func NewServer(
153161
options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur
154162
}
155163
return &server{
164+
lifecycleCtx: lifecycleCtx,
156165
AccessURL: accessURL,
157166
ID: id,
158167
Logger: logger,
@@ -1184,16 +1193,21 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
11841193
}
11851194
go func() {
11861195
for _, wait := range updates {
1187-
// Wait for the next potential timeout to occur. Note that we
1188-
// can't listen on the context here because we will hang around
1189-
// after this function has returned. The s also doesn't
1190-
// have a shutdown signal we can listen to.
1191-
<-wait
1192-
if err := s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{}); err != nil {
1193-
s.Logger.Error(ctx, "workspace notification after agent timeout failed",
1196+
select {
1197+
case <-s.lifecycleCtx.Done():
1198+
// If the server is shutting down, we don't want to wait around.
1199+
s.Logger.Debug(ctx, "stopping notifications due to server shutdown",
11941200
slog.F("workspace_build_id", workspaceBuild.ID),
1195-
slog.Error(err),
11961201
)
1202+
return
1203+
case <-wait:
1204+
// Wait for the next potential timeout to occur.
1205+
if err := s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{}); err != nil {
1206+
s.Logger.Error(ctx, "workspace notification after agent timeout failed",
1207+
slog.F("workspace_build_id", workspaceBuild.ID),
1208+
slog.Error(err),
1209+
)
1210+
}
11971211
}
11981212
}
11991213
}()

coderd/provisionerdserver/provisionerdserver_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1733,6 +1733,7 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
17331733
}
17341734

17351735
srv, err := provisionerdserver.NewServer(
1736+
ctx,
17361737
&url.URL{},
17371738
srvID,
17381739
slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}),

enterprise/coderd/provisionerdaemons.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
243243
logger := api.Logger.Named(fmt.Sprintf("ext-provisionerd-%s", name))
244244
logger.Info(ctx, "starting external provisioner daemon")
245245
srv, err := provisionerdserver.NewServer(
246+
api.ctx,
246247
api.AccessURL,
247248
uuid.New(),
248249
logger,

0 commit comments

Comments
 (0)