From b66707136d47060cb1810886ac801a180e5b8bbb Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Wed, 9 Nov 2022 17:52:00 +0000 Subject: [PATCH] fix: prevent races from processing build logs after channel close --- coderd/provisionerjobs.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index dda3c2607fffa..9111f9ca287a2 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -9,6 +9,7 @@ import ( "net/http" "sort" "strconv" + "sync" "time" "github.com/google/uuid" @@ -371,6 +372,7 @@ func (api *API) followLogs(jobID uuid.UUID) (<-chan database.ProvisionerJobLog, var ( closed = make(chan struct{}) bufferedLogs = make(chan database.ProvisionerJobLog, 128) + logMut = &sync.Mutex{} ) closeSubscribe, err := api.Pubsub.Subscribe( provisionerJobLogsChannel(jobID), @@ -380,12 +382,14 @@ func (api *API) followLogs(jobID uuid.UUID) (<-chan database.ProvisionerJobLog, return default: } + jlMsg := provisionerJobLogsMessage{} err := json.Unmarshal(message, &jlMsg) if err != nil { logger.Warn(ctx, "invalid provisioner job log on channel", slog.Error(err)) return } + if jlMsg.CreatedAfter != 0 { logs, err := api.Database.GetProvisionerLogsByIDBetween(ctx, database.GetProvisionerLogsByIDBetweenParams{ JobID: jobID, @@ -397,6 +401,18 @@ func (api *API) followLogs(jobID uuid.UUID) (<-chan database.ProvisionerJobLog, } for _, log := range logs { + // Sadly we have to use a mutex here because events may be + // handled out of order due to golang goroutine scheduling + // semantics (even though Postgres guarantees ordering of + // notifications). + logMut.Lock() + select { + case <-closed: + logMut.Unlock() + return + default: + } + select { case bufferedLogs <- log: logger.Debug(ctx, "subscribe buffered log", slog.F("stage", log.Stage)) @@ -406,12 +422,24 @@ func (api *API) followLogs(jobID uuid.UUID) (<-chan database.ProvisionerJobLog, // so just drop them. logger.Warn(ctx, "provisioner job log overflowing channel") } + logMut.Unlock() } } + if jlMsg.EndOfLogs { + // This mutex is to guard double-closes. + logMut.Lock() + select { + case <-closed: + logMut.Unlock() + return + default: + } logger.Debug(ctx, "got End of Logs") + close(closed) close(bufferedLogs) + logMut.Unlock() } }, )