Skip to content

Commit b1e4cfe

Browse files
authored
fix pubsub/poll race on provisioner job logs (#2783)
* fix pubsub/poll race on provisioner job logs
Signed-off-by: Spike Curtis <spike@coder.com>
* only cancel on non-error
Signed-off-by: Spike Curtis <spike@coder.com>
* Improve logging & comments
Signed-off-by: spikecurtis <spike@spikecurtis.com>
1 parent c1b3080 commit b1e4cfe

File tree

4 files changed

+320
-82
lines changed

4 files changed

+320
-82
lines changed

coderd/provisionerdaemons.go

+21-1
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ func (server *provisionerdServer) UpdateJob(ctx context.Context, request *proto.
380380
return nil, xerrors.Errorf("insert job logs: %w", err)
381381
}
382382
server.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID))
383-
data, err := json.Marshal(logs)
383+
data, err := json.Marshal(provisionerJobLogsMessage{Logs: logs})
384384
if err != nil {
385385
return nil, xerrors.Errorf("marshal job log: %w", err)
386386
}
@@ -549,6 +549,16 @@ func (server *provisionerdServer) FailJob(ctx context.Context, failJob *proto.Fa
549549
}
550550
case *proto.FailedJob_TemplateImport_:
551551
}
552+
553+
data, err := json.Marshal(provisionerJobLogsMessage{EndOfLogs: true})
554+
if err != nil {
555+
return nil, xerrors.Errorf("marshal job log: %w", err)
556+
}
557+
err = server.Pubsub.Publish(provisionerJobLogsChannel(jobID), data)
558+
if err != nil {
559+
server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
560+
return nil, xerrors.Errorf("publish end of job logs: %w", err)
561+
}
552562
return &proto.Empty{}, nil
553563
}
554564

@@ -711,6 +721,16 @@ func (server *provisionerdServer) CompleteJob(ctx context.Context, completed *pr
711721
reflect.TypeOf(completed.Type).String())
712722
}
713723

724+
data, err := json.Marshal(provisionerJobLogsMessage{EndOfLogs: true})
725+
if err != nil {
726+
return nil, xerrors.Errorf("marshal job log: %w", err)
727+
}
728+
err = server.Pubsub.Publish(provisionerJobLogsChannel(jobID), data)
729+
if err != nil {
730+
server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
731+
return nil, xerrors.Errorf("publish end of job logs: %w", err)
732+
}
733+
714734
server.Logger.Debug(ctx, "CompleteJob done", slog.F("job_id", jobID))
715735
return &proto.Empty{}, nil
716736
}

coderd/provisionerjobs.go

+114-80
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
// The combination of these responses should provide all current logs
2929
// to the consumer, and future logs are streamed in the follow request.
3030
func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
31+
logger := api.Logger.With(slog.F("job_id", job.ID))
3132
follow := r.URL.Query().Has("follow")
3233
afterRaw := r.URL.Query().Get("after")
3334
beforeRaw := r.URL.Query().Get("before")
@@ -38,6 +39,37 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
3839
return
3940
}
4041

42+
// if we are following logs, start the subscription before we query the database, so that we don't miss any logs
43+
// between the end of our query and the start of the subscription. We might get duplicates, so we'll keep track
44+
// of processed IDs.
45+
var bufferedLogs <-chan database.ProvisionerJobLog
46+
if follow {
47+
bl, closeFollow, err := api.followLogs(job.ID)
48+
if err != nil {
49+
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
50+
Message: "Internal error watching provisioner logs.",
51+
Detail: err.Error(),
52+
})
53+
return
54+
}
55+
defer closeFollow()
56+
bufferedLogs = bl
57+
58+
// Next query the job itself to see if it is complete. If so, the historical query to the database will return
59+
// the full set of logs. It's a little sad to have to query the job again, given that our caller definitely
60+
// has, but we need to query it *after* we start following the pubsub to avoid a race condition where the job
61+
// completes between the prior query and the start of following the pubsub. A more substantial refactor could
62+
// avoid this, but not worth it for one fewer query at this point.
63+
job, err = api.Database.GetProvisionerJobByID(r.Context(), job.ID)
64+
if err != nil {
65+
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
66+
Message: "Internal error querying job.",
67+
Detail: err.Error(),
68+
})
69+
return
70+
}
71+
}
72+
4173
var after time.Time
4274
// Only fetch logs created after the time provided.
4375
if afterRaw != "" {
@@ -78,26 +110,27 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
78110
}
79111
}
80112

81-
if !follow {
82-
logs, err := api.Database.GetProvisionerLogsByIDBetween(r.Context(), database.GetProvisionerLogsByIDBetweenParams{
83-
JobID: job.ID,
84-
CreatedAfter: after,
85-
CreatedBefore: before,
113+
logs, err := api.Database.GetProvisionerLogsByIDBetween(r.Context(), database.GetProvisionerLogsByIDBetweenParams{
114+
JobID: job.ID,
115+
CreatedAfter: after,
116+
CreatedBefore: before,
117+
})
118+
if errors.Is(err, sql.ErrNoRows) {
119+
err = nil
120+
}
121+
if err != nil {
122+
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
123+
Message: "Internal error fetching provisioner logs.",
124+
Detail: err.Error(),
86125
})
87-
if errors.Is(err, sql.ErrNoRows) {
88-
err = nil
89-
}
90-
if err != nil {
91-
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
92-
Message: "Internal error fetching provisioner logs.",
93-
Detail: err.Error(),
94-
})
95-
return
96-
}
97-
if logs == nil {
98-
logs = []database.ProvisionerJobLog{}
99-
}
126+
return
127+
}
128+
if logs == nil {
129+
logs = []database.ProvisionerJobLog{}
130+
}
100131

132+
if !follow {
133+
logger.Debug(r.Context(), "Finished non-follow job logs")
101134
httpapi.Write(rw, http.StatusOK, convertProvisionerJobLogs(logs))
102135
return
103136
}
@@ -118,82 +151,43 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
118151
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageText)
119152
defer wsNetConn.Close() // Also closes conn.
120153

121-
bufferedLogs := make(chan database.ProvisionerJobLog, 128)
122-
closeSubscribe, err := api.Pubsub.Subscribe(provisionerJobLogsChannel(job.ID), func(ctx context.Context, message []byte) {
123-
var logs []database.ProvisionerJobLog
124-
err := json.Unmarshal(message, &logs)
125-
if err != nil {
126-
api.Logger.Warn(ctx, fmt.Sprintf("invalid provisioner job log on channel %q: %s", provisionerJobLogsChannel(job.ID), err.Error()))
127-
return
128-
}
129-
130-
for _, log := range logs {
131-
select {
132-
case bufferedLogs <- log:
133-
api.Logger.Debug(r.Context(), "subscribe buffered log", slog.F("job_id", job.ID), slog.F("stage", log.Stage))
134-
default:
135-
// If this overflows users could miss logs streaming. This can happen
136-
// if a database request takes a long amount of time, and we get a lot of logs.
137-
api.Logger.Warn(ctx, "provisioner job log overflowing channel")
138-
}
139-
}
140-
})
141-
if err != nil {
142-
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
143-
Message: "Internal error watching provisioner logs.",
144-
Detail: err.Error(),
145-
})
146-
return
147-
}
148-
defer closeSubscribe()
149-
150-
provisionerJobLogs, err := api.Database.GetProvisionerLogsByIDBetween(ctx, database.GetProvisionerLogsByIDBetweenParams{
151-
JobID: job.ID,
152-
CreatedAfter: after,
153-
CreatedBefore: before,
154-
})
155-
if errors.Is(err, sql.ErrNoRows) {
156-
err = nil
157-
}
158-
if err != nil {
159-
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
160-
Message: "Internal error fetching provisioner logs.",
161-
Detail: err.Error(),
162-
})
163-
return
164-
}
154+
logIdsDone := make(map[uuid.UUID]bool)
165155

166156
// The Go stdlib JSON encoder appends a newline character after message write.
167157
encoder := json.NewEncoder(wsNetConn)
168-
for _, provisionerJobLog := range provisionerJobLogs {
158+
for _, provisionerJobLog := range logs {
159+
logIdsDone[provisionerJobLog.ID] = true
169160
err = encoder.Encode(convertProvisionerJobLog(provisionerJobLog))
170161
if err != nil {
171162
return
172163
}
173164
}
165+
if job.CompletedAt.Valid {
166+
// job was complete before we queried the database for historical logs, meaning we got everything. No need
167+
// to stream anything from the bufferedLogs.
168+
return
169+
}
174170

175-
ticker := time.NewTicker(250 * time.Millisecond)
176-
defer ticker.Stop()
177171
for {
178172
select {
179-
case <-r.Context().Done():
180-
api.Logger.Debug(context.Background(), "job logs context canceled", slog.F("job_id", job.ID))
173+
case <-ctx.Done():
174+
logger.Debug(context.Background(), "job logs context canceled")
181175
return
182-
case log := <-bufferedLogs:
183-
api.Logger.Debug(r.Context(), "subscribe encoding log", slog.F("job_id", job.ID), slog.F("stage", log.Stage))
184-
err = encoder.Encode(convertProvisionerJobLog(log))
185-
if err != nil {
176+
case log, ok := <-bufferedLogs:
177+
if !ok {
178+
logger.Debug(context.Background(), "done with published logs")
186179
return
187180
}
188-
case <-ticker.C:
189-
job, err := api.Database.GetProvisionerJobByID(r.Context(), job.ID)
190-
if err != nil {
191-
api.Logger.Warn(r.Context(), "streaming job logs; checking if completed", slog.Error(err), slog.F("job_id", job.ID.String()))
192-
continue
193-
}
194-
if job.CompletedAt.Valid {
195-
api.Logger.Debug(context.Background(), "streaming job logs done; job done", slog.F("job_id", job.ID))
196-
return
181+
if logIdsDone[log.ID] {
182+
logger.Debug(r.Context(), "subscribe duplicated log",
183+
slog.F("stage", log.Stage))
184+
} else {
185+
logger.Debug(r.Context(), "subscribe encoding log",
186+
slog.F("stage", log.Stage))
187+
err = encoder.Encode(convertProvisionerJobLog(log))
188+
if err != nil {
189+
return
190+
}
197191
}
198192
}
199193
}
@@ -343,3 +337,43 @@ func convertProvisionerJob(provisionerJob database.ProvisionerJob) codersdk.Prov
343337
func provisionerJobLogsChannel(jobID uuid.UUID) string {
344338
return fmt.Sprintf("provisioner-log-logs:%s", jobID)
345339
}
340+
341+
// provisionerJobLogsMessage is the message type published on the provisionerJobLogsChannel() channel
342+
type provisionerJobLogsMessage struct {
343+
EndOfLogs bool `json:"end_of_logs,omitempty"`
344+
Logs []database.ProvisionerJobLog `json:"logs,omitempty"`
345+
}
346+
347+
func (api *API) followLogs(jobID uuid.UUID) (<-chan database.ProvisionerJobLog, func(), error) {
348+
logger := api.Logger.With(slog.F("job_id", jobID))
349+
bufferedLogs := make(chan database.ProvisionerJobLog, 128)
350+
closeSubscribe, err := api.Pubsub.Subscribe(provisionerJobLogsChannel(jobID),
351+
func(ctx context.Context, message []byte) {
352+
jlMsg := provisionerJobLogsMessage{}
353+
err := json.Unmarshal(message, &jlMsg)
354+
if err != nil {
355+
logger.Warn(ctx, "invalid provisioner job log on channel", slog.Error(err))
356+
return
357+
}
358+
359+
for _, log := range jlMsg.Logs {
360+
select {
361+
case bufferedLogs <- log:
362+
logger.Debug(ctx, "subscribe buffered log", slog.F("stage", log.Stage))
363+
default:
364+
// If this overflows users could miss logs streaming. This can happen
365+
// when we get a lot of logs and the consumer isn't keeping up. We don't want to block the pubsub,
366+
// so just drop them.
367+
logger.Warn(ctx, "provisioner job log overflowing channel")
368+
}
369+
}
370+
if jlMsg.EndOfLogs {
371+
logger.Debug(ctx, "got End of Logs")
372+
close(bufferedLogs)
373+
}
374+
})
375+
if err != nil {
376+
return nil, nil, err
377+
}
378+
return bufferedLogs, closeSubscribe, nil
379+
}

0 commit comments

Comments (0)