diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go
index 7d83d3f905cfe..679bb7baf3fdc 100644
--- a/coderd/provisionerjobs.go
+++ b/coderd/provisionerjobs.go
@@ -9,10 +9,10 @@ import (
 	"net/http"
 	"sort"
 	"strconv"
-	"sync"
 	"time"
 
 	"github.com/google/uuid"
+	"go.uber.org/atomic"
 	"nhooyr.io/websocket"
 
 	"cdr.dev/slog"
@@ -374,19 +374,30 @@ func (api *API) followLogs(actor rbac.Subject, jobID uuid.UUID) (<-chan *databas
 	logger := api.Logger.With(slog.F("job_id", jobID))
 
 	var (
-		closed       = make(chan struct{})
-		bufferedLogs = make(chan *database.ProvisionerJobLog, 128)
-		logMut       = &sync.Mutex{}
+		bufferedLogs  = make(chan *database.ProvisionerJobLog, 128)
+		endOfLogs     atomic.Bool
+		lastSentLogID atomic.Int64
 	)
+
+	sendLog := func(log *database.ProvisionerJobLog) {
+		select {
+		case bufferedLogs <- log:
+			logger.Debug(context.Background(), "subscribe buffered log", slog.F("stage", log.Stage))
+			lastSentLogID.Store(log.ID)
+		default:
+			// If this overflows, users could miss streaming logs. This can
+			// happen when we get a lot of logs and the consumer isn't keeping
+			// up. We don't want to block the pubsub, so just drop them.
+			logger.Warn(context.Background(), "provisioner job log overflowing channel")
+		}
+	}
+
 	closeSubscribe, err := api.Pubsub.Subscribe(
 		provisionerJobLogsChannel(jobID),
 		func(ctx context.Context, message []byte) {
-			select {
-			case <-closed:
+			if endOfLogs.Load() {
 				return
-			default:
 			}
-
 			jlMsg := provisionerJobLogsMessage{}
 			err := json.Unmarshal(message, &jlMsg)
 			if err != nil {
@@ -394,6 +405,7 @@ func (api *API) followLogs(actor rbac.Subject, jobID uuid.UUID) (<-chan *databas
 				return
 			}
 
+			// CreatedAfter is sent when logs are streaming!
 			if jlMsg.CreatedAfter != 0 {
 				logs, err := api.Database.GetProvisionerLogsByIDBetween(dbauthz.As(ctx, actor), database.GetProvisionerLogsByIDBetweenParams{
 					JobID:        jobID,
@@ -403,54 +415,44 @@ func (api *API) followLogs(actor rbac.Subject, jobID uuid.UUID) (<-chan *databas
 					logger.Warn(ctx, "get provisioner logs", slog.Error(err))
 					return
 				}
-
 				for _, log := range logs {
-					// Sadly we have to use a mutex here because events may be
-					// handled out of order due to golang goroutine scheduling
-					// semantics (even though Postgres guarantees ordering of
-					// notifications).
-					logMut.Lock()
-					select {
-					case <-closed:
-						logMut.Unlock()
+					if endOfLogs.Load() {
+						// An end of logs message came in while we were fetching
+						// logs or processing them!
 						return
-					default:
 					}
 					log := log
-					select {
-					case bufferedLogs <- &log:
-						logger.Debug(ctx, "subscribe buffered log", slog.F("stage", log.Stage))
-					default:
-						// If this overflows users could miss logs streaming. This can happen
-						// we get a lot of logs and consumer isn't keeping up. We don't want to block the pubsub,
-						// so just drop them.
-						logger.Warn(ctx, "provisioner job log overflowing channel")
-					}
-					logMut.Unlock()
+					sendLog(&log)
 				}
 			}
 
+			// EndOfLogs is sent when logs are done streaming.
+			// We don't want to end the stream until we've sent all the logs,
+			// so we fetch logs after the last ID we've seen and send them!
 			if jlMsg.EndOfLogs {
-				// This mutex is to guard double-closes.
-				logMut.Lock()
-				select {
-				case <-closed:
-					logMut.Unlock()
+				endOfLogs.Store(true)
+				logs, err := api.Database.GetProvisionerLogsByIDBetween(dbauthz.As(ctx, actor), database.GetProvisionerLogsByIDBetweenParams{
+					JobID:        jobID,
+					CreatedAfter: lastSentLogID.Load(),
+				})
+				if err != nil {
+					logger.Warn(ctx, "get provisioner logs", slog.Error(err))
 					return
-				default:
+				}
+				for _, log := range logs {
+					log := log
+					sendLog(&log)
 				}
 				logger.Debug(ctx, "got End of Logs")
 				bufferedLogs <- nil
-				logMut.Unlock()
 			}
+
+			lastSentLogID.Store(jlMsg.CreatedAfter)
 		},
 	)
 	if err != nil {
 		return nil, nil, err
 	}
-	return bufferedLogs, func() {
-		closeSubscribe()
-		close(closed)
-		close(bufferedLogs)
-	}, nil
+	// We don't need to close the bufferedLogs channel because it will be garbage collected!
+	return bufferedLogs, closeSubscribe, nil
 }
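
For readers skimming the diff: the change drops the closed channel and logMut mutex in favor of two lock-free values from go.uber.org/atomic (the library the diff imports), and replaces close(bufferedLogs) with a nil sentinel sent on the channel. Below is a minimal, self-contained sketch of that pattern, not the real coderd wiring: logMsg is a hypothetical stand-in for database.ProvisionerJobLog, and a plain loop stands in for the pubsub callback.

package main

import (
	"fmt"

	"go.uber.org/atomic"
)

// logMsg is a hypothetical stand-in for database.ProvisionerJobLog.
type logMsg struct {
	ID    int64
	Stage string
}

func main() {
	var (
		bufferedLogs  = make(chan *logMsg, 128)
		endOfLogs     atomic.Bool  // replaces the closed channel + mutex
		lastSentLogID atomic.Int64 // high-water mark of delivered log IDs
	)

	// Non-blocking send, as in the diff's sendLog helper: record the last
	// delivered ID on success, and drop the log rather than block the
	// producer when the buffer is full.
	sendLog := func(l *logMsg) {
		select {
		case bufferedLogs <- l:
			lastSentLogID.Store(l.ID)
		default:
			fmt.Println("channel full, dropping log", l.ID)
		}
	}

	// Producer: refuse to send once end-of-logs has been observed.
	for i := int64(1); i <= 3; i++ {
		if endOfLogs.Load() {
			break
		}
		sendLog(&logMsg{ID: i, Stage: "apply"})
	}
	endOfLogs.Store(true)
	bufferedLogs <- nil // nil is the end-of-stream sentinel

	// Consumer: read until the nil sentinel rather than a channel close.
	for l := range bufferedLogs {
		if l == nil {
			fmt.Println("end of logs, last sent ID:", lastSentLogID.Load())
			return
		}
		fmt.Println("log", l.ID, l.Stage)
	}
}

Since nothing ever closes bufferedLogs, the nil sentinel is the only end-of-stream signal; once both sides drop their references the channel is garbage collected, which is why the new code can return closeSubscribe directly instead of a custom cleanup closure.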
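The EndOfLogs branch guards against a second race: logs can land between the last streaming notification that was processed and the end-of-logs message, and the overflow path can also drop logs outright. Here is a sketch of how the lastSentLogID high-water mark closes that gap, with a hypothetical fetchAfter helper and an in-memory slice standing in for GetProvisionerLogsByIDBetween and the database:

package main

import (
	"fmt"

	"go.uber.org/atomic"
)

// jobLog is a hypothetical stand-in for database.ProvisionerJobLog.
type jobLog struct{ ID int64 }

// fetchAfter mimics GetProvisionerLogsByIDBetween with only CreatedAfter
// set: it returns every log whose ID is greater than after.
func fetchAfter(db []jobLog, after int64) []jobLog {
	var out []jobLog
	for _, l := range db {
		if l.ID > after {
			out = append(out, l)
		}
	}
	return out
}

func main() {
	db := []jobLog{{1}, {2}, {3}, {4}} // every log the job produced

	var lastSentLogID atomic.Int64
	sent := make(chan *jobLog, 8)
	sendLog := func(l *jobLog) {
		sent <- l
		lastSentLogID.Store(l.ID)
	}

	// Streaming notifications only delivered logs 1 and 2; logs 3 and 4
	// landed after the last notification we processed.
	for i := range db[:2] {
		sendLog(&db[i])
	}

	// EndOfLogs arrives: fetch everything past the high-water mark so the
	// stream doesn't end with logs missing, then send the nil sentinel.
	for _, l := range fetchAfter(db, lastSentLogID.Load()) {
		l := l
		sendLog(&l)
	}
	sent <- nil

	for l := range sent {
		if l == nil {
			fmt.Println("end of logs")
			return
		}
		fmt.Println("delivered", l.ID)
	}
}

Storing endOfLogs before the catch-up query is what lets any concurrently running streaming handler bail out early, making the catch-up fetch the single source of truth for the tail of the stream.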