From bcf1d8f1107314c79d6ccf7f929388470793e7e1 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Mon, 27 Jun 2022 12:54:38 -0700 Subject: [PATCH 1/5] provisionerd sends failed or complete last Signed-off-by: Spike Curtis --- agent/reaper/reaper_stub.go | 2 +- coderd/coderdtest/coderdtest.go | 1 + coderd/provisionerdaemons.go | 12 + coderd/provisionerjobs.go | 4 + provisionerd/provisionerd.go | 855 +++-------------------------- provisionerd/provisionerd_test.go | 146 ++++- provisionerd/runner.go | 872 ++++++++++++++++++++++++++++++ pty/ptytest/ptytest.go | 2 +- 8 files changed, 1096 insertions(+), 798 deletions(-) create mode 100644 provisionerd/runner.go diff --git a/agent/reaper/reaper_stub.go b/agent/reaper/reaper_stub.go index 538a7db71887a..8cd87ab0bf3a7 100644 --- a/agent/reaper/reaper_stub.go +++ b/agent/reaper/reaper_stub.go @@ -7,6 +7,6 @@ func IsInitProcess() bool { return false } -func ForkReap(opt ...Option) error { +func ForkReap(_ ...Option) error { return nil } diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index d06007c50c3e4..db0d30ff8b23a 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -383,6 +383,7 @@ func UpdateTemplateVersion(t *testing.T, client *codersdk.Client, organizationID // AwaitTemplateImportJob awaits for an import job to reach completed status. func AwaitTemplateVersionJob(t *testing.T, client *codersdk.Client, version uuid.UUID) codersdk.TemplateVersion { + t.Logf("waiting for template version job %s", version) var templateVersion codersdk.TemplateVersion require.Eventually(t, func() bool { var err error diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go index 7a013e15d755d..626483994ddd9 100644 --- a/coderd/provisionerdaemons.go +++ b/coderd/provisionerdaemons.go @@ -331,6 +331,7 @@ func (server *provisionerdServer) UpdateJob(ctx context.Context, request *proto. if err != nil { return nil, xerrors.Errorf("parse job id: %w", err) } + server.Logger.Debug(ctx, "UpdateJob starting", slog.F("job_id", parsedID)) job, err := server.Database.GetProvisionerJobByID(ctx, parsedID) if err != nil { return nil, xerrors.Errorf("get job: %w", err) @@ -368,19 +369,27 @@ func (server *provisionerdServer) UpdateJob(ctx context.Context, request *proto. 
insertParams.Stage = append(insertParams.Stage, log.Stage) insertParams.Source = append(insertParams.Source, logSource) insertParams.Output = append(insertParams.Output, log.Output) + server.Logger.Debug(ctx, "job log", + slog.F("job_id", parsedID), + slog.F("stage", log.Stage), + slog.F("output", log.Output)) } logs, err := server.Database.InsertProvisionerJobLogs(context.Background(), insertParams) if err != nil { + server.Logger.Error(ctx, "failed to insert job logs", slog.F("job_id", parsedID), slog.Error(err)) return nil, xerrors.Errorf("insert job logs: %w", err) } + server.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID)) data, err := json.Marshal(logs) if err != nil { return nil, xerrors.Errorf("marshal job log: %w", err) } err = server.Pubsub.Publish(provisionerJobLogsChannel(parsedID), data) if err != nil { + server.Logger.Error(ctx, "failed to publish job logs", slog.F("job_id", parsedID), slog.Error(err)) return nil, xerrors.Errorf("publish job log: %w", err) } + server.Logger.Debug(ctx, "published job logs", slog.F("job_id", parsedID)) } if len(request.Readme) > 0 { @@ -488,6 +497,7 @@ func (server *provisionerdServer) FailJob(ctx context.Context, failJob *proto.Fa if err != nil { return nil, xerrors.Errorf("parse job id: %w", err) } + server.Logger.Debug(ctx, "FailJob starting", slog.F("job_id", jobID)) job, err := server.Database.GetProvisionerJobByID(ctx, jobID) if err != nil { return nil, xerrors.Errorf("get provisioner job: %w", err) @@ -547,6 +557,7 @@ func (server *provisionerdServer) CompleteJob(ctx context.Context, completed *pr if err != nil { return nil, xerrors.Errorf("parse job id: %w", err) } + server.Logger.Debug(ctx, "CompleteJob starting", slog.F("job_id", jobID)) job, err := server.Database.GetProvisionerJobByID(ctx, jobID) if err != nil { return nil, xerrors.Errorf("get job by id: %w", err) @@ -699,6 +710,7 @@ func (server *provisionerdServer) CompleteJob(ctx context.Context, completed *pr reflect.TypeOf(completed.Type).String()) } + server.Logger.Debug(ctx, "CompleteJob done", slog.F("job_id", jobID)) return &proto.Empty{}, nil } diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index 38d1efeae43df..8b163412f0ff4 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -130,6 +130,7 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job for _, log := range logs { select { case bufferedLogs <- log: + api.Logger.Debug(r.Context(), "subscribe buffered log", slog.F("job_id", job.ID), slog.F("stage", log.Stage)) default: // If this overflows users could miss logs streaming. This can happen // if a database request takes a long amount of time, and we get a lot of logs. 
@@ -176,8 +177,10 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job for { select { case <-r.Context().Done(): + api.Logger.Debug(context.Background(), "job logs context canceled", slog.F("job_id", job.ID)) return case log := <-bufferedLogs: + api.Logger.Debug(r.Context(), "subscribe encoding log", slog.F("job_id", job.ID), slog.F("stage", log.Stage)) err = encoder.Encode(convertProvisionerJobLog(log)) if err != nil { return @@ -189,6 +192,7 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job continue } if job.CompletedAt.Valid { + api.Logger.Debug(context.Background(), "streaming job logs done; job done", slog.F("job_id", job.ID)) return } } diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index ad69bfc2bbfb9..877f03b3e5bec 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -1,21 +1,14 @@ package provisionerd import ( - "archive/tar" - "bytes" "context" "errors" "fmt" "io" - "os" - "path" - "path/filepath" - "reflect" "strings" "sync" "time" - "github.com/google/uuid" "github.com/hashicorp/yamux" "github.com/spf13/afero" "go.uber.org/atomic" @@ -79,13 +72,8 @@ func New(clientDialer Dialer, opts *Options) *Server { closeCancel: ctxCancel, shutdown: make(chan struct{}), - - jobRunning: make(chan struct{}), - jobFailed: *atomic.NewBool(true), } - // Start off with a closed channel so - // isRunningJob() returns properly. - close(daemon.jobRunning) + go daemon.connect(ctx) return daemon } @@ -96,22 +84,13 @@ type Server struct { clientDialer Dialer clientValue atomic.Value - // Locked when closing the daemon. - closeMutex sync.Mutex + // Locked when closing the daemon, shutting down, or starting a new job. + mutex sync.Mutex closeContext context.Context closeCancel context.CancelFunc closeError error - - shutdownMutex sync.Mutex - shutdown chan struct{} - - // Locked when acquiring or failing a job. - jobMutex sync.Mutex - jobID string - jobRunningMutex sync.Mutex - jobRunning chan struct{} - jobFailed atomic.Bool - jobCancel context.CancelFunc + shutdown chan struct{} + activeJob jobRunner } // Connect establishes a connection to coderd. @@ -192,9 +171,13 @@ func (p *Server) client() (proto.DRPCProvisionerDaemonClient, bool) { return client, ok } +// isRunningJob returns true if a job is running. Caller must hold the mutex. func (p *Server) isRunningJob() bool { + if p.activeJob == nil { + return false + } select { - case <-p.jobRunning: + case <-p.activeJob.isDone(): return false default: return true @@ -203,8 +186,8 @@ func (p *Server) isRunningJob() bool { // Locks a job in the database, and runs it! 
func (p *Server) acquireJob(ctx context.Context) { - p.jobMutex.Lock() - defer p.jobMutex.Unlock() + p.mutex.Lock() + defer p.mutex.Unlock() if p.isClosed() { return } @@ -235,775 +218,78 @@ func (p *Server) acquireJob(ctx context.Context) { if job.JobId == "" { return } - ctx, p.jobCancel = context.WithCancel(ctx) - p.jobRunningMutex.Lock() - p.jobRunning = make(chan struct{}) - p.jobRunningMutex.Unlock() - p.jobFailed.Store(false) - p.jobID = job.JobId - p.opts.Logger.Info(context.Background(), "acquired job", slog.F("initiator_username", job.UserName), slog.F("provisioner", job.Provisioner), - slog.F("id", job.JobId), + slog.F("job_id", job.JobId), ) - go p.runJob(ctx, job) -} - -func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) { - shutdown, shutdownCancel := context.WithCancel(ctx) - defer shutdownCancel() - - complete, completeCancel := context.WithCancel(ctx) - defer completeCancel() - go func() { - ticker := time.NewTicker(p.opts.UpdateInterval) - defer ticker.Stop() - for { - select { - case <-p.closeContext.Done(): - return - case <-ctx.Done(): - return - case <-complete.Done(): - return - case <-p.shutdown: - p.opts.Logger.Info(ctx, "attempting graceful cancelation") - shutdownCancel() - return - case <-ticker.C: - } - client, ok := p.client() - if !ok { - continue - } - resp, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.JobId, - }) - if errors.Is(err, yamux.ErrSessionShutdown) || errors.Is(err, io.EOF) { - continue - } - if err != nil { - p.failActiveJobf("send periodic update: %s", err) - return - } - if !resp.Canceled { - continue - } - p.opts.Logger.Info(ctx, "attempting graceful cancelation") - shutdownCancel() - // Hard-cancel the job after a minute of pending cancelation. - timer := time.NewTimer(p.opts.ForceCancelInterval) - select { - case <-timer.C: - p.failActiveJobf("cancelation timed out") - return - case <-ctx.Done(): - timer.Stop() - return - } - } - }() - defer func() { - // Cleanup the work directory after execution. - for attempt := 0; attempt < 5; attempt++ { - err := p.opts.Filesystem.RemoveAll(p.opts.WorkDirectory) - if err != nil { - // On Windows, open files cannot be removed. - // When the provisioner daemon is shutting down, - // it may take a few milliseconds for processes to exit. - // See: https://github.com/golang/go/issues/50510 - p.opts.Logger.Debug(ctx, "failed to clean work directory; trying again", slog.Error(err)) - time.Sleep(250 * time.Millisecond) - continue - } - p.opts.Logger.Debug(ctx, "cleaned up work directory", slog.Error(err)) - break - } - - close(p.jobRunning) - }() - // It's safe to cast this ProvisionerType. This data is coming directly from coderd. 
- provisioner, hasProvisioner := p.opts.Provisioners[job.Provisioner] - if !hasProvisioner { - p.failActiveJobf("provisioner %q not registered", job.Provisioner) - return - } - - err := p.opts.Filesystem.MkdirAll(p.opts.WorkDirectory, 0700) - if err != nil { - p.failActiveJobf("create work directory %q: %s", p.opts.WorkDirectory, err) - return - } - - client, ok := p.client() - if !ok { - p.failActiveJobf("client disconnected") - return - } - _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.GetJobId(), - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_INFO, - Stage: "Setting up", - CreatedAt: time.Now().UTC().UnixMilli(), - }}, - }) - if err != nil { - p.failActiveJobf("write log: %s", err) - return - } - - p.opts.Logger.Info(ctx, "unpacking template source archive", slog.F("size_bytes", len(job.TemplateSourceArchive))) - reader := tar.NewReader(bytes.NewBuffer(job.TemplateSourceArchive)) - for { - header, err := reader.Next() - if errors.Is(err, io.EOF) { - break - } - if err != nil { - p.failActiveJobf("read template source archive: %s", err) - return - } - // #nosec - headerPath := filepath.Join(p.opts.WorkDirectory, header.Name) - if !strings.HasPrefix(headerPath, filepath.Clean(p.opts.WorkDirectory)) { - p.failActiveJobf("tar attempts to target relative upper directory") - return - } - mode := header.FileInfo().Mode() - if mode == 0 { - mode = 0600 - } - switch header.Typeflag { - case tar.TypeDir: - err = p.opts.Filesystem.MkdirAll(headerPath, mode) - if err != nil { - p.failActiveJobf("mkdir %q: %s", headerPath, err) - return - } - p.opts.Logger.Debug(context.Background(), "extracted directory", slog.F("path", headerPath)) - case tar.TypeReg: - file, err := p.opts.Filesystem.OpenFile(headerPath, os.O_CREATE|os.O_RDWR, mode) - if err != nil { - p.failActiveJobf("create file %q (mode %s): %s", headerPath, mode, err) - return - } - // Max file size of 10MiB. 
- size, err := io.CopyN(file, reader, 10<<20) - if errors.Is(err, io.EOF) { - err = nil - } - if err != nil { - _ = file.Close() - p.failActiveJobf("copy file %q: %s", headerPath, err) - return - } - err = file.Close() - if err != nil { - p.failActiveJobf("close file %q: %s", headerPath, err) - return - } - p.opts.Logger.Debug(context.Background(), "extracted file", - slog.F("size_bytes", size), - slog.F("path", headerPath), - slog.F("mode", mode), - ) - } - } - - switch jobType := job.Type.(type) { - case *proto.AcquiredJob_TemplateImport_: - p.opts.Logger.Debug(context.Background(), "acquired job is template import") - - p.runReadmeParse(ctx, job) - p.runTemplateImport(ctx, shutdown, provisioner, job) - case *proto.AcquiredJob_TemplateDryRun_: - p.opts.Logger.Debug(context.Background(), "acquired job is template dry-run", - slog.F("workspace_name", jobType.TemplateDryRun.Metadata.WorkspaceName), - slog.F("parameters", jobType.TemplateDryRun.ParameterValues), - ) - p.runTemplateDryRun(ctx, shutdown, provisioner, job) - case *proto.AcquiredJob_WorkspaceBuild_: - p.opts.Logger.Debug(context.Background(), "acquired job is workspace provision", - slog.F("workspace_name", jobType.WorkspaceBuild.WorkspaceName), - slog.F("state_length", len(jobType.WorkspaceBuild.State)), - slog.F("parameters", jobType.WorkspaceBuild.ParameterValues), - ) - - p.runWorkspaceBuild(ctx, shutdown, provisioner, job) - default: - p.failActiveJobf("unknown job type %q; ensure your provisioner daemon is up-to-date", reflect.TypeOf(job.Type).String()) - return - } - - client, ok = p.client() + provisioner, ok := p.opts.Provisioners[job.Provisioner] if !ok { - return - } - // Ensure the job is still running to output. - // It's possible the job has failed. - if p.isRunningJob() { - _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.GetJobId(), - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_INFO, - Stage: "Cleaning Up", - CreatedAt: time.Now().UTC().UnixMilli(), - }}, + err := p.failJob(ctx, &proto.FailedJob{ + JobId: job.JobId, + Error: fmt.Sprintf("no provisioner %s", job.Provisioner), }) if err != nil { - p.failActiveJobf("write log: %s", err) - return + p.opts.Logger.Error(context.Background(), "failed to call FailJob", + slog.F("job_id", job.JobId), slog.Error(err)) } - - p.opts.Logger.Info(context.Background(), "completed job", slog.F("id", job.JobId)) - } -} - -// ReadmeFile is the location we look for to extract documentation from template -// versions. 
-const ReadmeFile = "README.md" - -func (p *Server) runReadmeParse(ctx context.Context, job *proto.AcquiredJob) { - client, ok := p.client() - if !ok { - p.failActiveJobf("client disconnected") - return - } - - fi, err := afero.ReadFile(p.opts.Filesystem, path.Join(p.opts.WorkDirectory, ReadmeFile)) - if err != nil { - _, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.GetJobId(), - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_DEBUG, - Stage: "No README.md provided", - CreatedAt: time.Now().UTC().UnixMilli(), - }}, - }) - if err != nil { - p.failActiveJobf("write log: %s", err) - } - - return - } - - _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.GetJobId(), - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_INFO, - Stage: "Adding README.md...", - CreatedAt: time.Now().UTC().UnixMilli(), - }}, - Readme: fi, - }) - if err != nil { - p.failActiveJobf("write log: %s", err) return } + p.activeJob = newRunner(job, p, p.opts.Logger, p.opts.Filesystem, p.opts.WorkDirectory, provisioner, + p.opts.UpdateInterval, p.opts.ForceCancelInterval) + go p.activeJob.start() } -func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob) { - client, ok := p.client() - if !ok { - p.failActiveJobf("client disconnected") - return - } - - // Parse parameters and update the job with the parameter specs - _, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.GetJobId(), - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_INFO, - Stage: "Parsing template parameters", - CreatedAt: time.Now().UTC().UnixMilli(), - }}, - }) - if err != nil { - p.failActiveJobf("write log: %s", err) - return - } - parameterSchemas, err := p.runTemplateImportParse(ctx, provisioner, job) - if err != nil { - p.failActiveJobf("run parse: %s", err) - return - } - updateResponse, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.JobId, - ParameterSchemas: parameterSchemas, - }) - if err != nil { - p.failActiveJobf("update job: %s", err) - return - } - - valueByName := map[string]*sdkproto.ParameterValue{} - for _, parameterValue := range updateResponse.ParameterValues { - valueByName[parameterValue.Name] = parameterValue - } - for _, parameterSchema := range parameterSchemas { - _, ok := valueByName[parameterSchema.Name] - if !ok { - p.failActiveJobf("%s: %s", missingParameterErrorText, parameterSchema.Name) - return - } - } - - // Determine persistent resources - _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.GetJobId(), - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_INFO, - Stage: "Detecting persistent resources", - CreatedAt: time.Now().UTC().UnixMilli(), - }}, - }) - if err != nil { - p.failActiveJobf("write log: %s", err) - return - } - startResources, err := p.runTemplateImportProvision(ctx, shutdown, provisioner, job, updateResponse.ParameterValues, &sdkproto.Provision_Metadata{ - CoderUrl: job.GetTemplateImport().Metadata.CoderUrl, - WorkspaceTransition: sdkproto.WorkspaceTransition_START, - }) - if err != nil { - p.failActiveJobf("template import provision for start: %s", err) - return - } - - // Determine ephemeral resources. 
- _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.GetJobId(), - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_INFO, - Stage: "Detecting ephemeral resources", - CreatedAt: time.Now().UTC().UnixMilli(), - }}, - }) - if err != nil { - p.failActiveJobf("write log: %s", err) - return - } - stopResources, err := p.runTemplateImportProvision(ctx, shutdown, provisioner, job, updateResponse.ParameterValues, &sdkproto.Provision_Metadata{ - CoderUrl: job.GetTemplateImport().Metadata.CoderUrl, - WorkspaceTransition: sdkproto.WorkspaceTransition_STOP, - }) - if err != nil { - p.failActiveJobf("template import provision for stop: %s", err) - return - } - - p.completeJob(&proto.CompletedJob{ - JobId: job.JobId, - Type: &proto.CompletedJob_TemplateImport_{ - TemplateImport: &proto.CompletedJob_TemplateImport{ - StartResources: startResources, - StopResources: stopResources, - }, - }, - }) +func retryable(err error) bool { + return xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) || + // annoyingly, dRPC sometimes returns context.Canceled if the transport was closed, even if the context for + // the RPC *is not canceled*. Retrying is fine if the RPC context is not canceled. + xerrors.Is(err, context.Canceled) } -// Parses parameter schemas from source. -func (p *Server) runTemplateImportParse(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob) ([]*sdkproto.ParameterSchema, error) { - client, ok := p.client() - if !ok { - return nil, xerrors.New("client disconnected") - } - stream, err := provisioner.Parse(ctx, &sdkproto.Parse_Request{ - Directory: p.opts.WorkDirectory, - }) - if err != nil { - return nil, xerrors.Errorf("parse source: %w", err) - } - defer stream.Close() - for { - msg, err := stream.Recv() - if err != nil { - return nil, xerrors.Errorf("recv parse source: %w", err) +// clientDoWithRetries runs the function f with a client, and retries with backoff until either the error returned +// is not retryable() or the context expires. +func (p *Server) clientDoWithRetries( + ctx context.Context, f func(context.Context, proto.DRPCProvisionerDaemonClient) (any, error)) ( + any, error) { + for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(ctx); { + client, ok := p.client() + if !ok { + continue } - switch msgType := msg.Type.(type) { - case *sdkproto.Parse_Response_Log: - p.opts.Logger.Debug(context.Background(), "parse job logged", - slog.F("level", msgType.Log.Level), - slog.F("output", msgType.Log.Output), - ) - - _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.JobId, - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER, - Level: msgType.Log.Level, - CreatedAt: time.Now().UTC().UnixMilli(), - Output: msgType.Log.Output, - Stage: "Parse parameters", - }}, - }) - if err != nil { - return nil, xerrors.Errorf("update job: %w", err) - } - case *sdkproto.Parse_Response_Complete: - p.opts.Logger.Info(context.Background(), "parse complete", - slog.F("parameter_schemas", msgType.Complete.ParameterSchemas)) - - return msgType.Complete.ParameterSchemas, nil - default: - return nil, xerrors.Errorf("invalid message type %q received from provisioner", - reflect.TypeOf(msg.Type).String()) + resp, err := f(ctx, client) + if retryable(err) { + continue } + return resp, err } + return nil, ctx.Err() } -// Performs a dry-run provision when importing a template. 
-// This is used to detect resources that would be provisioned -// for a workspace in various states. -func (p *Server) runTemplateImportProvision(ctx, shutdown context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob, values []*sdkproto.ParameterValue, metadata *sdkproto.Provision_Metadata) ([]*sdkproto.Resource, error) { - var stage string - switch metadata.WorkspaceTransition { - case sdkproto.WorkspaceTransition_START: - stage = "Detecting persistent resources" - case sdkproto.WorkspaceTransition_STOP: - stage = "Detecting ephemeral resources" - } - stream, err := provisioner.Provision(ctx) - if err != nil { - return nil, xerrors.Errorf("provision: %w", err) - } - defer stream.Close() - go func() { - select { - case <-ctx.Done(): - return - case <-shutdown.Done(): - _ = stream.Send(&sdkproto.Provision_Request{ - Type: &sdkproto.Provision_Request_Cancel{ - Cancel: &sdkproto.Provision_Cancel{}, - }, - }) - } - }() - err = stream.Send(&sdkproto.Provision_Request{ - Type: &sdkproto.Provision_Request_Start{ - Start: &sdkproto.Provision_Start{ - Directory: p.opts.WorkDirectory, - ParameterValues: values, - DryRun: true, - Metadata: metadata, - }, - }, +func (p *Server) updateJob(ctx context.Context, in *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { + out, err := p.clientDoWithRetries(ctx, func(ctx context.Context, client proto.DRPCProvisionerDaemonClient) (any, error) { + return client.UpdateJob(ctx, in) }) if err != nil { - return nil, xerrors.Errorf("start provision: %w", err) - } - - for { - msg, err := stream.Recv() - if err != nil { - return nil, xerrors.Errorf("recv import provision: %w", err) - } - switch msgType := msg.Type.(type) { - case *sdkproto.Provision_Response_Log: - p.opts.Logger.Debug(context.Background(), "template import provision job logged", - slog.F("level", msgType.Log.Level), - slog.F("output", msgType.Log.Output), - ) - client, ok := p.client() - if !ok { - continue - } - _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.JobId, - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER, - Level: msgType.Log.Level, - CreatedAt: time.Now().UTC().UnixMilli(), - Output: msgType.Log.Output, - Stage: stage, - }}, - }) - if err != nil { - return nil, xerrors.Errorf("send job update: %w", err) - } - case *sdkproto.Provision_Response_Complete: - if msgType.Complete.Error != "" { - p.opts.Logger.Info(context.Background(), "dry-run provision failure", - slog.F("error", msgType.Complete.Error), - ) - - return nil, xerrors.New(msgType.Complete.Error) - } - - p.opts.Logger.Info(context.Background(), "parse dry-run provision successful", - slog.F("resource_count", len(msgType.Complete.Resources)), - slog.F("resources", msgType.Complete.Resources), - slog.F("state_length", len(msgType.Complete.State)), - ) - - return msgType.Complete.Resources, nil - default: - return nil, xerrors.Errorf("invalid message type %q received from provisioner", - reflect.TypeOf(msg.Type).String()) - } - } -} - -func (p *Server) runTemplateDryRun(ctx, shutdown context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob) { - // Ensure all metadata fields are set as they are all optional for dry-run. 
- metadata := job.GetTemplateDryRun().GetMetadata() - metadata.WorkspaceTransition = sdkproto.WorkspaceTransition_START - if metadata.CoderUrl == "" { - metadata.CoderUrl = "http://localhost:3000" - } - if metadata.WorkspaceName == "" { - metadata.WorkspaceName = "dryrun" - } - metadata.WorkspaceOwner = job.UserName - if metadata.WorkspaceOwner == "" { - metadata.WorkspaceOwner = "dryrunner" - } - if metadata.WorkspaceId == "" { - id, err := uuid.NewRandom() - if err != nil { - p.failActiveJobf("generate random ID: %s", err) - return - } - metadata.WorkspaceId = id.String() - } - if metadata.WorkspaceOwnerId == "" { - id, err := uuid.NewRandom() - if err != nil { - p.failActiveJobf("generate random ID: %s", err) - return - } - metadata.WorkspaceOwnerId = id.String() - } - - // Run the template import provision task since it's already a dry run. - resources, err := p.runTemplateImportProvision(ctx, - shutdown, - provisioner, - job, - job.GetTemplateDryRun().GetParameterValues(), - metadata, - ) - if err != nil { - p.failActiveJobf("run dry-run provision job: %s", err) - return + return nil, err } - - p.completeJob(&proto.CompletedJob{ - JobId: job.JobId, - Type: &proto.CompletedJob_TemplateDryRun_{ - TemplateDryRun: &proto.CompletedJob_TemplateDryRun{ - Resources: resources, - }, - }, - }) + // nolint: forcetypeassert + return out.(*proto.UpdateJobResponse), nil } -func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob) { - var stage string - switch job.GetWorkspaceBuild().Metadata.WorkspaceTransition { - case sdkproto.WorkspaceTransition_START: - stage = "Starting workspace" - case sdkproto.WorkspaceTransition_STOP: - stage = "Stopping workspace" - case sdkproto.WorkspaceTransition_DESTROY: - stage = "Destroying workspace" - } - - client, ok := p.client() - if !ok { - p.failActiveJobf("client disconnected") - return - } - _, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.GetJobId(), - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_INFO, - Stage: stage, - CreatedAt: time.Now().UTC().UnixMilli(), - }}, - }) - if err != nil { - p.failActiveJobf("write log: %s", err) - return - } - - stream, err := provisioner.Provision(ctx) - if err != nil { - p.failActiveJobf("provision: %s", err) - return - } - defer stream.Close() - go func() { - select { - case <-ctx.Done(): - return - case <-shutdown.Done(): - _ = stream.Send(&sdkproto.Provision_Request{ - Type: &sdkproto.Provision_Request_Cancel{ - Cancel: &sdkproto.Provision_Cancel{}, - }, - }) - } - }() - err = stream.Send(&sdkproto.Provision_Request{ - Type: &sdkproto.Provision_Request_Start{ - Start: &sdkproto.Provision_Start{ - Directory: p.opts.WorkDirectory, - ParameterValues: job.GetWorkspaceBuild().ParameterValues, - Metadata: job.GetWorkspaceBuild().Metadata, - State: job.GetWorkspaceBuild().State, - }, - }, +func (p *Server) failJob(ctx context.Context, in *proto.FailedJob) error { + _, err := p.clientDoWithRetries(ctx, func(ctx context.Context, client proto.DRPCProvisionerDaemonClient) (any, error) { + return client.FailJob(ctx, in) }) - if err != nil { - p.failActiveJobf("start provision: %s", err) - return - } - - for { - msg, err := stream.Recv() - if err != nil { - p.failActiveJobf("recv workspace provision: %s", err) - return - } - switch msgType := msg.Type.(type) { - case *sdkproto.Provision_Response_Log: - p.opts.Logger.Debug(context.Background(), "workspace provision job logged", - 
slog.F("level", msgType.Log.Level), - slog.F("output", msgType.Log.Output), - slog.F("workspace_build_id", job.GetWorkspaceBuild().WorkspaceBuildId), - ) - - _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.JobId, - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER, - Level: msgType.Log.Level, - CreatedAt: time.Now().UTC().UnixMilli(), - Output: msgType.Log.Output, - Stage: stage, - }}, - }) - if err != nil { - p.failActiveJobf("send job update: %s", err) - return - } - case *sdkproto.Provision_Response_Complete: - if msgType.Complete.Error != "" { - p.opts.Logger.Info(context.Background(), "provision failed; updating state", - slog.F("state_length", len(msgType.Complete.State)), - ) - - p.failActiveJob(&proto.FailedJob{ - Error: msgType.Complete.Error, - Type: &proto.FailedJob_WorkspaceBuild_{ - WorkspaceBuild: &proto.FailedJob_WorkspaceBuild{ - State: msgType.Complete.State, - }, - }, - }) - return - } - - p.completeJob(&proto.CompletedJob{ - JobId: job.JobId, - Type: &proto.CompletedJob_WorkspaceBuild_{ - WorkspaceBuild: &proto.CompletedJob_WorkspaceBuild{ - State: msgType.Complete.State, - Resources: msgType.Complete.Resources, - }, - }, - }) - p.opts.Logger.Info(context.Background(), "provision successful; marked job as complete", - slog.F("resource_count", len(msgType.Complete.Resources)), - slog.F("resources", msgType.Complete.Resources), - slog.F("state_length", len(msgType.Complete.State)), - ) - // Stop looping! - return - default: - p.failActiveJobf("invalid message type %T received from provisioner", msg.Type) - return - } - } -} - -func (p *Server) completeJob(job *proto.CompletedJob) { - for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(p.closeContext); { - client, ok := p.client() - if !ok { - continue - } - // Complete job may need to be async if we disconnected... - // When we reconnect we can flush any of these cached values. 
- _, err := client.CompleteJob(p.closeContext, job) - if xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) { - continue - } - if err != nil { - p.opts.Logger.Warn(p.closeContext, "failed to complete job", slog.Error(err)) - p.failActiveJobf(err.Error()) - return - } - break - } + return err } -func (p *Server) failActiveJobf(format string, args ...interface{}) { - p.failActiveJob(&proto.FailedJob{ - Error: fmt.Sprintf(format, args...), +func (p *Server) completeJob(ctx context.Context, in *proto.CompletedJob) error { + _, err := p.clientDoWithRetries(ctx, func(ctx context.Context, client proto.DRPCProvisionerDaemonClient) (any, error) { + return client.CompleteJob(ctx, in) }) -} - -func (p *Server) failActiveJob(failedJob *proto.FailedJob) { - p.jobMutex.Lock() - defer p.jobMutex.Unlock() - if !p.isRunningJob() { - return - } - if p.jobFailed.Load() { - p.opts.Logger.Debug(context.Background(), "job has already been marked as failed", slog.F("error_messsage", failedJob.Error)) - return - } - p.jobFailed.Store(true) - p.jobCancel() - p.opts.Logger.Info(context.Background(), "failing running job", - slog.F("error_message", failedJob.Error), - slog.F("job_id", p.jobID), - ) - failedJob.JobId = p.jobID - for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(p.closeContext); { - client, ok := p.client() - if !ok { - continue - } - _, err := client.FailJob(p.closeContext, failedJob) - if xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) { - continue - } - if err != nil { - if p.isClosed() { - return - } - p.opts.Logger.Warn(context.Background(), "failed to notify of error; job is no longer running", slog.Error(err)) - return - } - p.opts.Logger.Debug(context.Background(), "marked running job as failed") - return - } + return err } // isClosed returns whether the API is closed or not. @@ -1029,18 +315,23 @@ func (p *Server) isShutdown() bool { // Shutdown triggers a graceful exit of each registered provisioner. // It exits when an active job stops. func (p *Server) Shutdown(ctx context.Context) error { - p.shutdownMutex.Lock() - defer p.shutdownMutex.Unlock() + p.mutex.Lock() + defer p.mutex.Unlock() if !p.isRunningJob() { return nil } p.opts.Logger.Info(ctx, "attempting graceful shutdown") close(p.shutdown) + if p.activeJob == nil { + return nil + } + // wait for active job + p.activeJob.cancel() select { case <-ctx.Done(): p.opts.Logger.Warn(ctx, "graceful shutdown failed", slog.Error(ctx.Err())) return ctx.Err() - case <-p.jobRunning: + case <-p.activeJob.isDone(): p.opts.Logger.Info(ctx, "gracefully shutdown") return nil } @@ -1053,8 +344,8 @@ func (p *Server) Close() error { // closeWithError closes the provisioner; subsequent reads/writes will return the error err. 
func (p *Server) closeWithError(err error) error { - p.closeMutex.Lock() - defer p.closeMutex.Unlock() + p.mutex.Lock() + defer p.mutex.Unlock() if p.isClosed() { return p.closeError } @@ -1064,10 +355,18 @@ func (p *Server) closeWithError(err error) error { if err != nil { errMsg = err.Error() } - p.failActiveJobf(errMsg) - p.jobRunningMutex.Lock() - <-p.jobRunning - p.jobRunningMutex.Unlock() + if p.activeJob != nil { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + failErr := p.activeJob.fail(ctx, &proto.FailedJob{Error: errMsg}) + if failErr != nil { + p.activeJob.forceStop() + } + if err == nil { + err = failErr + } + } + p.closeCancel() p.opts.Logger.Debug(context.Background(), "closing server with error", slog.Error(err)) diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go index 0031d662311f4..d92dc271a5c2e 100644 --- a/provisionerd/provisionerd_test.go +++ b/provisionerd/provisionerd_test.go @@ -4,6 +4,7 @@ import ( "archive/tar" "bytes" "context" + "fmt" "io" "os" "path/filepath" @@ -12,6 +13,7 @@ import ( "time" "github.com/hashicorp/yamux" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" "go.uber.org/goleak" @@ -32,6 +34,18 @@ func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } +func closedWithin(c chan struct{}, d time.Duration) func() bool { + return func() bool { + pop := time.After(d) + select { + case <-c: + return true + case <-pop: + return false + } + } +} + func TestProvisionerd(t *testing.T) { t.Parallel() @@ -54,7 +68,7 @@ func TestProvisionerd(t *testing.T) { defer close(completeChan) return nil, xerrors.New("an error") }, provisionerd.Provisioners{}) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.NoError(t, closer.Close()) }) @@ -77,7 +91,7 @@ func TestProvisionerd(t *testing.T) { updateJob: noopUpdateJob, }), nil }, provisionerd.Provisioners{}) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.NoError(t, closer.Close()) }) @@ -122,7 +136,7 @@ func TestProvisionerd(t *testing.T) { }), }) closerMutex.Unlock() - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.NoError(t, closer.Close()) }) @@ -160,7 +174,7 @@ func TestProvisionerd(t *testing.T) { }, provisionerd.Provisioners{ "someprovisioner": createProvisionerClient(t, provisionerTestServer{}), }) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.NoError(t, closer.Close()) }) @@ -203,7 +217,7 @@ func TestProvisionerd(t *testing.T) { }, }), }) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.NoError(t, closer.Close()) }) @@ -308,7 +322,7 @@ func TestProvisionerd(t *testing.T) { }, }), }) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.True(t, didLog.Load()) require.True(t, didComplete.Load()) require.True(t, didDryRun.Load()) @@ -388,7 +402,7 @@ func TestProvisionerd(t *testing.T) { }), }) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.True(t, didLog.Load()) require.True(t, didComplete.Load()) require.NoError(t, closer.Close()) @@ -459,7 +473,7 @@ func TestProvisionerd(t *testing.T) { }, }), }) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.True(t, didLog.Load()) require.True(t, didComplete.Load()) require.NoError(t, closer.Close()) @@ -514,7 +528,7 @@ func 
TestProvisionerd(t *testing.T) { }, }), }) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.True(t, didFail.Load()) require.NoError(t, closer.Close()) }) @@ -587,10 +601,10 @@ func TestProvisionerd(t *testing.T) { }, }), }) - <-updateChan + require.Condition(t, closedWithin(updateChan, 5*time.Second)) err := server.Shutdown(context.Background()) require.NoError(t, err) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.NoError(t, server.Close()) }) @@ -617,15 +631,21 @@ func TestProvisionerd(t *testing.T) { }, nil }, updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { + resp := &proto.UpdateJobResponse{} if len(update.Logs) > 0 && update.Logs[0].Source == proto.LogSource_PROVISIONER { // Close on a log so we know when the job is in progress! updated.Do(func() { close(updateChan) }) } - return &proto.UpdateJobResponse{ - Canceled: true, - }, nil + // start returning Canceled once we've gotten at least one log. + select { + case <-updateChan: + resp.Canceled = true + default: + // pass + } + return resp, nil }, failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { completed.Do(func() { @@ -664,8 +684,8 @@ func TestProvisionerd(t *testing.T) { }, }), }) - <-updateChan - <-completeChan + require.Condition(t, closedWithin(updateChan, 5*time.Second)) + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.NoError(t, server.Close()) }) @@ -703,6 +723,7 @@ func TestProvisionerd(t *testing.T) { return &proto.UpdateJobResponse{}, nil }, failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { + assert.Equal(t, job.JobId, "test") if second.Load() { completeOnce.Do(func() { close(completeChan) }) return &proto.Empty{}, nil @@ -736,7 +757,7 @@ func TestProvisionerd(t *testing.T) { }, }), }) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.NoError(t, server.Close()) }) @@ -808,7 +829,96 @@ func TestProvisionerd(t *testing.T) { }, }), }) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) + require.NoError(t, server.Close()) + }) + + t.Run("UpdatesBeforeComplete", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, nil) + m := sync.Mutex{} + var ops []string + completeChan := make(chan struct{}) + completeOnce := sync.Once{} + + server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + m.Lock() + defer m.Unlock() + logger.Info(ctx, "AcquiredJob called.") + if len(ops) > 0 { + return &proto.AcquiredJob{}, nil + } + ops = append(ops, "AcquireJob") + + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + TemplateSourceArchive: createTar(t, map[string]string{ + "test.txt": "content", + }), + Type: &proto.AcquiredJob_WorkspaceBuild_{ + WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ + Metadata: &sdkproto.Provision_Metadata{}, + }, + }, + }, nil + }, + updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { + m.Lock() + defer m.Unlock() + logger.Info(ctx, "UpdateJob called.") + ops = append(ops, "UpdateJob") + for _, log := range update.Logs { + ops = append(ops, fmt.Sprintf("Log: %s | %s", log.Stage, log.Output)) + } + return 
&proto.UpdateJobResponse{}, nil + }, + completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { + m.Lock() + defer m.Unlock() + logger.Info(ctx, "CompleteJob called.") + ops = append(ops, "CompleteJob") + completeOnce.Do(func() { close(completeChan) }) + return &proto.Empty{}, nil + }, + failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { + m.Lock() + defer m.Unlock() + logger.Info(ctx, "FailJob called.") + ops = append(ops, "FailJob") + return &proto.Empty{}, nil + }, + }), nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error { + err := stream.Send(&sdkproto.Provision_Response{ + Type: &sdkproto.Provision_Response_Log{ + Log: &sdkproto.Log{ + Level: sdkproto.LogLevel_DEBUG, + Output: "wow", + }, + }, + }) + require.NoError(t, err) + + err = stream.Send(&sdkproto.Provision_Response{ + Type: &sdkproto.Provision_Response_Complete{ + Complete: &sdkproto.Provision_Complete{}, + }, + }) + require.NoError(t, err) + return nil + }, + }), + }) + require.Condition(t, closedWithin(completeChan, 5*time.Second)) + m.Lock() + defer m.Unlock() + require.Equal(t, ops[len(ops)-1], "CompleteJob") + require.Contains(t, ops[0:len(ops)-1], "Log: Cleaning Up | ") require.NoError(t, server.Close()) }) } diff --git a/provisionerd/runner.go b/provisionerd/runner.go new file mode 100644 index 0000000000000..af89b8c8f293e --- /dev/null +++ b/provisionerd/runner.go @@ -0,0 +1,872 @@ +package provisionerd + +import ( + "archive/tar" + "bytes" + "context" + "errors" + "fmt" + "io" + "os" + "path" + "path/filepath" + "reflect" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/spf13/afero" + "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/coder/coder/provisionerd/proto" + sdkproto "github.com/coder/coder/provisionersdk/proto" +) + +type jobRunner interface { + start() + cancel() + isDone() <-chan any + fail(ctx context.Context, f *proto.FailedJob) error + forceStop() +} + +type runner struct { + job *proto.AcquiredJob + sender messageSender + logger slog.Logger + filesystem afero.Fs + workDirectory string + provisioner sdkproto.DRPCProvisionerClient + updateInterval time.Duration + forceCancelInterval time.Duration + + // mutex controls access to all the following variables. + mutex *sync.Mutex + // used to wait for the failedJob or completedJob to be populated + cond *sync.Cond + // closed when the job ends. + done chan any + failedJob *proto.FailedJob + completedJob *proto.CompletedJob + // active as long as we are not canceled + gracefulContext context.Context + cancelFunc context.CancelFunc + // active as long as we haven't been force stopped + forceStopContext context.Context + forceStopFunc context.CancelFunc + // setting this false signals that no more messages about this job should be sent. Usually this means that a + // terminal message like FailedJob or CompletedJob has been sent, but if we are force canceled, we may set this + // false and not send one. 
+ okToSend bool +} + +type messageSender interface { + updateJob(ctx context.Context, in *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) + failJob(ctx context.Context, in *proto.FailedJob) error + completeJob(ctx context.Context, in *proto.CompletedJob) error +} + +func newRunner( + job *proto.AcquiredJob, + sender messageSender, + logger slog.Logger, + filesystem afero.Fs, + workDirectory string, + provisioner sdkproto.DRPCProvisionerClient, + updateInterval time.Duration, + forceCancelInterval time.Duration) jobRunner { + m := new(sync.Mutex) + + // we need to create our contexts here in case a call to cancel() comes immediately. + logCtx := slog.With(context.Background(), slog.F("job_id", job.JobId)) + forceStopContext, forceStopFunc := context.WithCancel(logCtx) + gracefulContext, cancelFunc := context.WithCancel(forceStopContext) + + return &runner{ + job: job, + sender: sender, + logger: logger, + filesystem: filesystem, + workDirectory: workDirectory, + provisioner: provisioner, + updateInterval: updateInterval, + forceCancelInterval: forceCancelInterval, + mutex: m, + cond: sync.NewCond(m), + done: make(chan any), + okToSend: true, + forceStopContext: forceStopContext, + forceStopFunc: forceStopFunc, + gracefulContext: gracefulContext, + cancelFunc: cancelFunc, + } +} + +func (r *runner) start() { + r.mutex.Lock() + defer r.mutex.Unlock() + defer r.forceStopFunc() + + // the idea here is to run two goroutines to work on the job: run and heartbeat, then use the `r.cond` to wait until + // the job is either complete or failed. This function then sends the complete or failed message --- the exception + // to this is if something calls fail() on the runner; either something external, like the server getting closed, + // or the heartbeat goroutine timing out after attempting to gracefully cancel. If something calls fail(), then + // the failure is sent on that goroutine on the context passed into fail(), and it marks okToSend false to signal + // us here that this function should not also send a terminal message. + go r.run() + go r.heartbeat() + for r.failedJob == nil && r.completedJob == nil { + r.cond.Wait() + } + if !r.okToSend { + // nothing else to do. + return + } + if r.failedJob != nil { + r.logger.Debug(r.forceStopContext, "sending FailedJob") + err := r.sender.failJob(r.forceStopContext, r.failedJob) + if err != nil { + r.logger.Error(r.forceStopContext, "send FailJob", slog.Error(err)) + } + r.logger.Info(r.forceStopContext, "sent FailedJob") + } else { + r.logger.Debug(r.forceStopContext, "sending CompletedJob") + err := r.sender.completeJob(r.forceStopContext, r.completedJob) + if err != nil { + r.logger.Error(r.forceStopContext, "send CompletedJob", slog.Error(err)) + } + r.logger.Info(r.forceStopContext, "sent CompletedJob") + } + close(r.done) + r.okToSend = false +} + +// cancel initiates a cancel on the job, but allows it to keep running to do so gracefully. Read from isDone() to +// be notified when the job completes. +func (r *runner) cancel() { + r.cancelFunc() +} + +func (r *runner) isDone() <-chan any { + return r.done +} + +// fail immediately halts updates and, if the job is not complete sends FailJob to the coder server. Running goroutines +// are canceled but complete asynchronously (although they are prevented from further updating the job to the coder +// server). The provided context sets how long to keep trying to send the FailJob. 
+func (r *runner) fail(ctx context.Context, f *proto.FailedJob) error { + f.JobId = r.job.JobId + r.mutex.Lock() + defer r.mutex.Unlock() + if !r.okToSend { + return nil // already done + } + r.cancel() + if r.failedJob == nil { + r.failedJob = f + r.cond.Signal() + } + // here we keep the original failed reason if there already was one, but we hadn't yet sent it. It is likely more + // informative of the job failing due to some problem running it, whereas this function is used to externally + // force a fail. + err := r.sender.failJob(ctx, r.failedJob) + r.okToSend = false + r.forceStopFunc() + close(r.done) + return err +} + +// setComplete is an internal function to set the job to completed. This does not send the completedJob. +func (r *runner) setComplete(c *proto.CompletedJob) { + r.mutex.Lock() + defer r.mutex.Unlock() + if r.completedJob == nil { + r.completedJob = c + r.cond.Signal() + } +} + +// setFail is an internal function to set the job to failed. This does not send the failedJob. +func (r *runner) setFail(f *proto.FailedJob) { + r.mutex.Lock() + defer r.mutex.Unlock() + if r.failedJob == nil { + f.JobId = r.job.GetJobId() + r.failedJob = f + r.cond.Signal() + } +} + +// forceStop signals all goroutines to stop and prevents any further API calls back to coder server for this job +func (r *runner) forceStop() { + r.mutex.Lock() + defer r.mutex.Unlock() + r.forceStopFunc() + if !r.okToSend { + return + } + r.okToSend = false + close(r.done) + // doesn't matter what we put here, since it won't get sent! Just need something to satisfy the condition in + // start() + r.failedJob = &proto.FailedJob{} + r.cond.Signal() +} + +func (r *runner) update(ctx context.Context, u *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { + r.mutex.Lock() + defer r.mutex.Unlock() + if !r.okToSend { + return nil, xerrors.New("update skipped; job complete or failed") + } + return r.sender.updateJob(ctx, u) +} + +func (r *runner) run() { + // push the fail/succeed write onto the defer stack before the cleanup, so that cleanup happens before this. + // Failures during this function should write to the _local_ failedJob variable, then return. + var failedJob *proto.FailedJob + var completedJob *proto.CompletedJob + defer func() { + if failedJob != nil { + r.setFail(failedJob) + return + } + r.setComplete(completedJob) + }() + + defer func() { + r.logCleanup(r.forceStopContext) + + // Cleanup the work directory after execution. + for attempt := 0; attempt < 5; attempt++ { + err := r.filesystem.RemoveAll(r.workDirectory) + if err != nil { + // On Windows, open files cannot be removed. + // When the provisioner daemon is shutting down, + // it may take a few milliseconds for processes to exit. 
+ // See: https://github.com/golang/go/issues/50510 + r.logger.Debug(r.forceStopContext, "failed to clean work directory; trying again", slog.Error(err)) + time.Sleep(250 * time.Millisecond) + continue + } + r.logger.Debug(r.forceStopContext, "cleaned up work directory", slog.Error(err)) + break + } + }() + + err := r.filesystem.MkdirAll(r.workDirectory, 0700) + if err != nil { + failedJob = r.failedJobf("create work directory %q: %s", r.workDirectory, err) + return + } + + _, err = r.update(r.forceStopContext, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER_DAEMON, + Level: sdkproto.LogLevel_INFO, + Stage: "Setting up", + CreatedAt: time.Now().UTC().UnixMilli(), + }}, + }) + if err != nil { + failedJob = r.failedJobf("write log: %s", err) + return + } + + r.logger.Info(r.forceStopContext, "unpacking template source archive", + slog.F("size_bytes", len(r.job.TemplateSourceArchive))) + reader := tar.NewReader(bytes.NewBuffer(r.job.TemplateSourceArchive)) + for { + header, err := reader.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + failedJob = r.failedJobf("read template source archive: %s", err) + return + } + // #nosec + headerPath := filepath.Join(r.workDirectory, header.Name) + if !strings.HasPrefix(headerPath, filepath.Clean(r.workDirectory)) { + failedJob = r.failedJobf("tar attempts to target relative upper directory") + return + } + mode := header.FileInfo().Mode() + if mode == 0 { + mode = 0600 + } + switch header.Typeflag { + case tar.TypeDir: + err = r.filesystem.MkdirAll(headerPath, mode) + if err != nil { + failedJob = r.failedJobf("mkdir %q: %s", headerPath, err) + return + } + r.logger.Debug(context.Background(), "extracted directory", slog.F("path", headerPath)) + case tar.TypeReg: + file, err := r.filesystem.OpenFile(headerPath, os.O_CREATE|os.O_RDWR, mode) + if err != nil { + failedJob = r.failedJobf("create file %q (mode %s): %s", headerPath, mode, err) + return + } + // Max file size of 10MiB. 
+ size, err := io.CopyN(file, reader, 10<<20) + if errors.Is(err, io.EOF) { + err = nil + } + if err != nil { + _ = file.Close() + failedJob = r.failedJobf("copy file %q: %s", headerPath, err) + return + } + err = file.Close() + if err != nil { + failedJob = r.failedJobf("close file %q: %s", headerPath, err) + return + } + r.logger.Debug(context.Background(), "extracted file", + slog.F("size_bytes", size), + slog.F("path", headerPath), + slog.F("mode", mode), + ) + } + } + + switch jobType := r.job.Type.(type) { + case *proto.AcquiredJob_TemplateImport_: + r.logger.Debug(context.Background(), "acquired job is template import") + + failedJob = r.runReadmeParse() + if failedJob == nil { + completedJob, failedJob = r.runTemplateImport() + } + case *proto.AcquiredJob_TemplateDryRun_: + r.logger.Debug(context.Background(), "acquired job is template dry-run", + slog.F("workspace_name", jobType.TemplateDryRun.Metadata.WorkspaceName), + slog.F("parameters", jobType.TemplateDryRun.ParameterValues), + ) + completedJob, failedJob = r.runTemplateDryRun() + case *proto.AcquiredJob_WorkspaceBuild_: + r.logger.Debug(context.Background(), "acquired job is workspace provision", + slog.F("workspace_name", jobType.WorkspaceBuild.WorkspaceName), + slog.F("state_length", len(jobType.WorkspaceBuild.State)), + slog.F("parameters", jobType.WorkspaceBuild.ParameterValues), + ) + + completedJob, failedJob = r.runWorkspaceBuild() + default: + failedJob = r.failedJobf("unknown job type %q; ensure your provisioner daemon is up-to-date", + reflect.TypeOf(r.job.Type).String()) + } +} + +func (r *runner) heartbeat() { + ticker := time.NewTicker(r.updateInterval) + defer ticker.Stop() + for { + select { + case <-r.gracefulContext.Done(): + return + case <-ticker.C: + } + + resp, err := r.update(r.forceStopContext, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + }) + if err != nil { + err = r.fail(r.forceStopContext, r.failedJobf("send periodic update: %s", err)) + if err != nil { + r.logger.Error(r.forceStopContext, "failed to call FailJob", slog.Error(err)) + } + return + } + if !resp.Canceled { + continue + } + r.logger.Info(r.forceStopContext, "attempting graceful cancelation") + r.cancel() + // Hard-cancel the job if graceful cancelation does not complete within the force-cancel interval. + timer := time.NewTimer(r.forceCancelInterval) + select { + case <-timer.C: + r.logger.Warn(r.forceStopContext, "cancel timed out") + err := r.fail(r.forceStopContext, r.failedJobf("cancel timed out")) + if err != nil { + r.logger.Warn(r.forceStopContext, "failed to call FailJob", slog.Error(err)) + } + return + case <-r.isDone(): + timer.Stop() + return + case <-r.forceStopContext.Done(): + timer.Stop() + return + } + } +} + +func (r *runner) logCleanup(ctx context.Context) { + _, err := r.update(ctx, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER_DAEMON, + Level: sdkproto.LogLevel_INFO, + Stage: "Cleaning Up", + CreatedAt: time.Now().UTC().UnixMilli(), + }}, + }) + if err != nil { + r.logger.Warn(ctx, "failed to log cleanup") + return + } +} + +// ReadmeFile is the location we look for to extract documentation from template +// versions.
+const ReadmeFile = "README.md" + +func (r *runner) runReadmeParse() *proto.FailedJob { + fi, err := afero.ReadFile(r.filesystem, path.Join(r.workDirectory, ReadmeFile)) + if err != nil { + _, err := r.update(r.forceStopContext, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER_DAEMON, + Level: sdkproto.LogLevel_DEBUG, + Stage: "No README.md provided", + CreatedAt: time.Now().UTC().UnixMilli(), + }}, + }) + if err != nil { + return r.failedJobf("write log: %s", err) + } + + return nil + } + + _, err = r.update(r.forceStopContext, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER_DAEMON, + Level: sdkproto.LogLevel_INFO, + Stage: "Adding README.md...", + CreatedAt: time.Now().UTC().UnixMilli(), + }}, + Readme: fi, + }) + if err != nil { + return r.failedJobf("write log: %s", err) + } + return nil +} + +func (r *runner) runTemplateImport() (*proto.CompletedJob, *proto.FailedJob) { + // Parse parameters and update the job with the parameter specs + _, err := r.update(r.forceStopContext, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER_DAEMON, + Level: sdkproto.LogLevel_INFO, + Stage: "Parsing template parameters", + CreatedAt: time.Now().UTC().UnixMilli(), + }}, + }) + if err != nil { + return nil, r.failedJobf("write log: %s", err) + } + parameterSchemas, err := r.runTemplateImportParse() + if err != nil { + return nil, r.failedJobf("run parse: %s", err) + } + updateResponse, err := r.update(r.forceStopContext, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + ParameterSchemas: parameterSchemas, + }) + if err != nil { + return nil, r.failedJobf("update job: %s", err) + } + + valueByName := map[string]*sdkproto.ParameterValue{} + for _, parameterValue := range updateResponse.ParameterValues { + valueByName[parameterValue.Name] = parameterValue + } + for _, parameterSchema := range parameterSchemas { + _, ok := valueByName[parameterSchema.Name] + if !ok { + return nil, r.failedJobf("%s: %s", missingParameterErrorText, parameterSchema.Name) + } + } + + // Determine persistent resources + _, err = r.update(r.forceStopContext, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER_DAEMON, + Level: sdkproto.LogLevel_INFO, + Stage: "Detecting persistent resources", + CreatedAt: time.Now().UTC().UnixMilli(), + }}, + }) + if err != nil { + return nil, r.failedJobf("write log: %s", err) + } + startResources, err := r.runTemplateImportProvision(updateResponse.ParameterValues, &sdkproto.Provision_Metadata{ + CoderUrl: r.job.GetTemplateImport().Metadata.CoderUrl, + WorkspaceTransition: sdkproto.WorkspaceTransition_START, + }) + if err != nil { + return nil, r.failedJobf("template import provision for start: %s", err) + } + + // Determine ephemeral resources. 
+ _, err = r.update(r.forceStopContext, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER_DAEMON, + Level: sdkproto.LogLevel_INFO, + Stage: "Detecting ephemeral resources", + CreatedAt: time.Now().UTC().UnixMilli(), + }}, + }) + if err != nil { + return nil, r.failedJobf("write log: %s", err) + } + stopResources, err := r.runTemplateImportProvision(updateResponse.ParameterValues, &sdkproto.Provision_Metadata{ + CoderUrl: r.job.GetTemplateImport().Metadata.CoderUrl, + WorkspaceTransition: sdkproto.WorkspaceTransition_STOP, + }) + if err != nil { + return nil, r.failedJobf("template import provision for stop: %s", err) + } + + return &proto.CompletedJob{ + JobId: r.job.JobId, + Type: &proto.CompletedJob_TemplateImport_{ + TemplateImport: &proto.CompletedJob_TemplateImport{ + StartResources: startResources, + StopResources: stopResources, + }, + }, + }, nil +} + +// Parses parameter schemas from source. +func (r *runner) runTemplateImportParse() ([]*sdkproto.ParameterSchema, error) { + stream, err := r.provisioner.Parse(r.forceStopContext, &sdkproto.Parse_Request{ + Directory: r.workDirectory, + }) + if err != nil { + return nil, xerrors.Errorf("parse source: %w", err) + } + defer stream.Close() + for { + msg, err := stream.Recv() + if err != nil { + return nil, xerrors.Errorf("recv parse source: %w", err) + } + switch msgType := msg.Type.(type) { + case *sdkproto.Parse_Response_Log: + r.logger.Debug(context.Background(), "parse job logged", + slog.F("level", msgType.Log.Level), + slog.F("output", msgType.Log.Output), + ) + + _, err = r.update(r.forceStopContext, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER, + Level: msgType.Log.Level, + CreatedAt: time.Now().UTC().UnixMilli(), + Output: msgType.Log.Output, + Stage: "Parse parameters", + }}, + }) + if err != nil { + return nil, xerrors.Errorf("update job: %w", err) + } + case *sdkproto.Parse_Response_Complete: + r.logger.Info(context.Background(), "parse complete", + slog.F("parameter_schemas", msgType.Complete.ParameterSchemas)) + + return msgType.Complete.ParameterSchemas, nil + default: + return nil, xerrors.Errorf("invalid message type %q received from provisioner", + reflect.TypeOf(msg.Type).String()) + } + } +} + +// Performs a dry-run provision when importing a template. +// This is used to detect resources that would be provisioned +// for a workspace in various states. 
+func (r *runner) runTemplateImportProvision(values []*sdkproto.ParameterValue, metadata *sdkproto.Provision_Metadata) ([]*sdkproto.Resource, error) { + var stage string + switch metadata.WorkspaceTransition { + case sdkproto.WorkspaceTransition_START: + stage = "Detecting persistent resources" + case sdkproto.WorkspaceTransition_STOP: + stage = "Detecting ephemeral resources" + } + // use the forceStopContext so that if we attempt to gracefully cancel, the stream will still be available for us + // to send the cancel to the provisioner + stream, err := r.provisioner.Provision(r.forceStopContext) + if err != nil { + return nil, xerrors.Errorf("provision: %w", err) + } + defer stream.Close() + go func() { + select { + case <-r.forceStopContext.Done(): + return + case <-r.gracefulContext.Done(): + _ = stream.Send(&sdkproto.Provision_Request{ + Type: &sdkproto.Provision_Request_Cancel{ + Cancel: &sdkproto.Provision_Cancel{}, + }, + }) + } + }() + err = stream.Send(&sdkproto.Provision_Request{ + Type: &sdkproto.Provision_Request_Start{ + Start: &sdkproto.Provision_Start{ + Directory: r.workDirectory, + ParameterValues: values, + DryRun: true, + Metadata: metadata, + }, + }, + }) + if err != nil { + return nil, xerrors.Errorf("start provision: %w", err) + } + + for { + msg, err := stream.Recv() + if err != nil { + return nil, xerrors.Errorf("recv import provision: %w", err) + } + switch msgType := msg.Type.(type) { + case *sdkproto.Provision_Response_Log: + r.logger.Debug(context.Background(), "template import provision job logged", + slog.F("level", msgType.Log.Level), + slog.F("output", msgType.Log.Output), + ) + _, err = r.update(r.forceStopContext, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER, + Level: msgType.Log.Level, + CreatedAt: time.Now().UTC().UnixMilli(), + Output: msgType.Log.Output, + Stage: stage, + }}, + }) + if err != nil { + return nil, xerrors.Errorf("send job update: %w", err) + } + case *sdkproto.Provision_Response_Complete: + if msgType.Complete.Error != "" { + r.logger.Info(context.Background(), "dry-run provision failure", + slog.F("error", msgType.Complete.Error), + ) + + return nil, xerrors.New(msgType.Complete.Error) + } + + r.logger.Info(context.Background(), "parse dry-run provision successful", + slog.F("resource_count", len(msgType.Complete.Resources)), + slog.F("resources", msgType.Complete.Resources), + slog.F("state_length", len(msgType.Complete.State)), + ) + + return msgType.Complete.Resources, nil + default: + return nil, xerrors.Errorf("invalid message type %q received from provisioner", + reflect.TypeOf(msg.Type).String()) + } + } +} + +func (r *runner) runTemplateDryRun() ( + *proto.CompletedJob, *proto.FailedJob) { + // Ensure all metadata fields are set as they are all optional for dry-run. 
+ metadata := r.job.GetTemplateDryRun().GetMetadata() + metadata.WorkspaceTransition = sdkproto.WorkspaceTransition_START + if metadata.CoderUrl == "" { + metadata.CoderUrl = "http://localhost:3000" + } + if metadata.WorkspaceName == "" { + metadata.WorkspaceName = "dryrun" + } + metadata.WorkspaceOwner = r.job.UserName + if metadata.WorkspaceOwner == "" { + metadata.WorkspaceOwner = "dryrunner" + } + if metadata.WorkspaceId == "" { + id, err := uuid.NewRandom() + if err != nil { + return nil, r.failedJobf("generate random ID: %s", err) + } + metadata.WorkspaceId = id.String() + } + if metadata.WorkspaceOwnerId == "" { + id, err := uuid.NewRandom() + if err != nil { + return nil, r.failedJobf("generate random ID: %s", err) + } + metadata.WorkspaceOwnerId = id.String() + } + + // Run the template import provision task since it's already a dry run. + resources, err := r.runTemplateImportProvision( + r.job.GetTemplateDryRun().GetParameterValues(), + metadata, + ) + if err != nil { + return nil, r.failedJobf("run dry-run provision job: %s", err) + } + + return &proto.CompletedJob{ + JobId: r.job.JobId, + Type: &proto.CompletedJob_TemplateDryRun_{ + TemplateDryRun: &proto.CompletedJob_TemplateDryRun{ + Resources: resources, + }, + }, + }, nil +} + +func (r *runner) runWorkspaceBuild() ( + *proto.CompletedJob, *proto.FailedJob) { + var stage string + switch r.job.GetWorkspaceBuild().Metadata.WorkspaceTransition { + case sdkproto.WorkspaceTransition_START: + stage = "Starting workspace" + case sdkproto.WorkspaceTransition_STOP: + stage = "Stopping workspace" + case sdkproto.WorkspaceTransition_DESTROY: + stage = "Destroying workspace" + } + + _, err := r.update(r.forceStopContext, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER_DAEMON, + Level: sdkproto.LogLevel_INFO, + Stage: stage, + CreatedAt: time.Now().UTC().UnixMilli(), + }}, + }) + if err != nil { + return nil, r.failedJobf("write log: %s", err) + } + + // use the forceStopContext so that if we attempt to gracefully cancel, the stream will still be available for us + // to send the cancel to the provisioner + stream, err := r.provisioner.Provision(r.forceStopContext) + if err != nil { + return nil, r.failedJobf("provision: %s", err) + } + defer stream.Close() + go func() { + select { + case <-r.forceStopContext.Done(): + return + case <-r.gracefulContext.Done(): + _ = stream.Send(&sdkproto.Provision_Request{ + Type: &sdkproto.Provision_Request_Cancel{ + Cancel: &sdkproto.Provision_Cancel{}, + }, + }) + } + }() + err = stream.Send(&sdkproto.Provision_Request{ + Type: &sdkproto.Provision_Request_Start{ + Start: &sdkproto.Provision_Start{ + Directory: r.workDirectory, + ParameterValues: r.job.GetWorkspaceBuild().ParameterValues, + Metadata: r.job.GetWorkspaceBuild().Metadata, + State: r.job.GetWorkspaceBuild().State, + }, + }, + }) + if err != nil { + return nil, r.failedJobf("start provision: %s", err) + } + + for { + msg, err := stream.Recv() + if err != nil { + return nil, r.failedJobf("recv workspace provision: %s", err) + } + switch msgType := msg.Type.(type) { + case *sdkproto.Provision_Response_Log: + r.logger.Debug(context.Background(), "workspace provision job logged", + slog.F("level", msgType.Log.Level), + slog.F("output", msgType.Log.Output), + slog.F("workspace_build_id", r.job.GetWorkspaceBuild().WorkspaceBuildId), + ) + + _, err = r.update(r.forceStopContext, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER, + 
Level: msgType.Log.Level, + CreatedAt: time.Now().UTC().UnixMilli(), + Output: msgType.Log.Output, + Stage: stage, + }}, + }) + if err != nil { + return nil, r.failedJobf("send job update: %s", err) + } + case *sdkproto.Provision_Response_Complete: + if msgType.Complete.Error != "" { + r.logger.Info(context.Background(), "provision failed; updating state", + slog.F("state_length", len(msgType.Complete.State)), + ) + + return nil, &proto.FailedJob{ + JobId: r.job.JobId, + Error: msgType.Complete.Error, + Type: &proto.FailedJob_WorkspaceBuild_{ + WorkspaceBuild: &proto.FailedJob_WorkspaceBuild{ + State: msgType.Complete.State, + }, + }, + } + } + + r.logger.Debug(context.Background(), "provision complete no error") + r.logger.Info(context.Background(), "provision successful; marked job as complete", + slog.F("resource_count", len(msgType.Complete.Resources)), + slog.F("resources", msgType.Complete.Resources), + slog.F("state_length", len(msgType.Complete.State)), + ) + // Stop looping! + return &proto.CompletedJob{ + JobId: r.job.JobId, + Type: &proto.CompletedJob_WorkspaceBuild_{ + WorkspaceBuild: &proto.CompletedJob_WorkspaceBuild{ + State: msgType.Complete.State, + Resources: msgType.Complete.Resources, + }, + }, + }, nil + default: + return nil, r.failedJobf("invalid message type %T received from provisioner", msg.Type) + } + } +} + +func (r *runner) failedJobf(format string, args ...interface{}) *proto.FailedJob { + return &proto.FailedJob{ + JobId: r.job.JobId, + Error: fmt.Sprintf(format, args...), + } +} diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index e465a35d06d24..178998b5a21f9 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -89,7 +89,7 @@ func (p *PTY) ExpectMatch(str string) string { case <-timer.C: } _ = p.Close() - p.t.Errorf("match exceeded deadline: wanted %q; got %q", str, buffer.String()) + p.t.Errorf("%s match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String()) }() for { var r rune From e7a7119622d0ac43b28864eec4e80680fd604e95 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 30 Jun 2022 11:41:25 -0700 Subject: [PATCH 2/5] Move runner into package Signed-off-by: Spike Curtis --- provisionerd/provisionerd.go | 13 +++++-------- provisionerd/provisionerd_test.go | 6 ++++-- provisionerd/{ => runner}/runner.go | 8 ++++++-- 3 files changed, 15 insertions(+), 12 deletions(-) rename provisionerd/{ => runner}/runner.go (99%) diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index 877f03b3e5bec..eefc346594ce2 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -9,26 +9,23 @@ import ( "sync" "time" + "cdr.dev/slog" "github.com/hashicorp/yamux" "github.com/spf13/afero" "go.uber.org/atomic" "golang.org/x/xerrors" - "cdr.dev/slog" "github.com/coder/coder/provisionerd/proto" + "github.com/coder/coder/provisionerd/runner" sdkproto "github.com/coder/coder/provisionersdk/proto" "github.com/coder/retry" ) -const ( - missingParameterErrorText = "missing parameter" -) - // IsMissingParameterError returns whether the error message provided // is a missing parameter error. This can indicate to consumers that // they should check parameters. func IsMissingParameterError(err string) bool { - return strings.Contains(err, missingParameterErrorText) + return strings.Contains(err, runner.MissingParameterErrorText) } // Dialer represents the function to create a daemon client connection. 
@@ -90,7 +87,7 @@ type Server struct { closeCancel context.CancelFunc closeError error shutdown chan struct{} - activeJob jobRunner + activeJob runner.jobRunner } // Connect establishes a connection to coderd. @@ -236,7 +233,7 @@ func (p *Server) acquireJob(ctx context.Context) { } return } - p.activeJob = newRunner(job, p, p.opts.Logger, p.opts.Filesystem, p.opts.WorkDirectory, provisioner, + p.activeJob = runner.newRunner(job, p, p.opts.Logger, p.opts.Filesystem, p.opts.WorkDirectory, provisioner, p.opts.UpdateInterval, p.opts.ForceCancelInterval) go p.activeJob.start() } diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go index d92dc271a5c2e..9b544494a1129 100644 --- a/provisionerd/provisionerd_test.go +++ b/provisionerd/provisionerd_test.go @@ -12,6 +12,8 @@ import ( "testing" "time" + "github.com/coder/coder/provisionerd/runner" + "github.com/hashicorp/yamux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -245,8 +247,8 @@ func TestProvisionerd(t *testing.T) { JobId: "test", Provisioner: "someprovisioner", TemplateSourceArchive: createTar(t, map[string]string{ - "test.txt": "content", - provisionerd.ReadmeFile: "# A cool template 😎\n", + "test.txt": "content", + runner.ReadmeFile: "# A cool template 😎\n", }), Type: &proto.AcquiredJob_TemplateImport_{ TemplateImport: &proto.AcquiredJob_TemplateImport{ diff --git a/provisionerd/runner.go b/provisionerd/runner/runner.go similarity index 99% rename from provisionerd/runner.go rename to provisionerd/runner/runner.go index af89b8c8f293e..11f560a19bf6e 100644 --- a/provisionerd/runner.go +++ b/provisionerd/runner/runner.go @@ -1,4 +1,4 @@ -package provisionerd +package runner import ( "archive/tar" @@ -25,6 +25,10 @@ import ( sdkproto "github.com/coder/coder/provisionersdk/proto" ) +const ( + MissingParameterErrorText = "missing parameter" +) + type jobRunner interface { start() cancel() @@ -498,7 +502,7 @@ func (r *runner) runTemplateImport() (*proto.CompletedJob, *proto.FailedJob) { for _, parameterSchema := range parameterSchemas { _, ok := valueByName[parameterSchema.Name] if !ok { - return nil, r.failedJobf("%s: %s", missingParameterErrorText, parameterSchema.Name) + return nil, r.failedJobf("%s: %s", provisionerd.missingParameterErrorText, parameterSchema.Name) } } From 97ed447b56171dc8075dc342475c4a30cfc9b038 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 30 Jun 2022 11:57:13 -0700 Subject: [PATCH 3/5] Remove jobRunner interface Signed-off-by: Spike Curtis --- provisionerd/provisionerd.go | 24 +++---- provisionerd/runner/runner.go | 123 ++++++++++++++++------------------ 2 files changed, 70 insertions(+), 77 deletions(-) diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index eefc346594ce2..62d4efc8579e6 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -87,7 +87,7 @@ type Server struct { closeCancel context.CancelFunc closeError error shutdown chan struct{} - activeJob runner.jobRunner + activeJob *runner.Runner } // Connect establishes a connection to coderd. 
@@ -174,7 +174,7 @@ func (p *Server) isRunningJob() bool { return false } select { - case <-p.activeJob.isDone(): + case <-p.activeJob.Done(): return false default: return true @@ -223,7 +223,7 @@ func (p *Server) acquireJob(ctx context.Context) { provisioner, ok := p.opts.Provisioners[job.Provisioner] if !ok { - err := p.failJob(ctx, &proto.FailedJob{ + err := p.FailJob(ctx, &proto.FailedJob{ JobId: job.JobId, Error: fmt.Sprintf("no provisioner %s", job.Provisioner), }) @@ -233,9 +233,9 @@ func (p *Server) acquireJob(ctx context.Context) { } return } - p.activeJob = runner.newRunner(job, p, p.opts.Logger, p.opts.Filesystem, p.opts.WorkDirectory, provisioner, + p.activeJob = runner.NewRunner(job, p, p.opts.Logger, p.opts.Filesystem, p.opts.WorkDirectory, provisioner, p.opts.UpdateInterval, p.opts.ForceCancelInterval) - go p.activeJob.start() + go p.activeJob.Start() } func retryable(err error) bool { @@ -264,7 +264,7 @@ func (p *Server) clientDoWithRetries( return nil, ctx.Err() } -func (p *Server) updateJob(ctx context.Context, in *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { +func (p *Server) UpdateJob(ctx context.Context, in *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { out, err := p.clientDoWithRetries(ctx, func(ctx context.Context, client proto.DRPCProvisionerDaemonClient) (any, error) { return client.UpdateJob(ctx, in) }) @@ -275,14 +275,14 @@ func (p *Server) updateJob(ctx context.Context, in *proto.UpdateJobRequest) (*pr return out.(*proto.UpdateJobResponse), nil } -func (p *Server) failJob(ctx context.Context, in *proto.FailedJob) error { +func (p *Server) FailJob(ctx context.Context, in *proto.FailedJob) error { _, err := p.clientDoWithRetries(ctx, func(ctx context.Context, client proto.DRPCProvisionerDaemonClient) (any, error) { return client.FailJob(ctx, in) }) return err } -func (p *Server) completeJob(ctx context.Context, in *proto.CompletedJob) error { +func (p *Server) CompleteJob(ctx context.Context, in *proto.CompletedJob) error { _, err := p.clientDoWithRetries(ctx, func(ctx context.Context, client proto.DRPCProvisionerDaemonClient) (any, error) { return client.CompleteJob(ctx, in) }) @@ -323,12 +323,12 @@ func (p *Server) Shutdown(ctx context.Context) error { return nil } // wait for active job - p.activeJob.cancel() + p.activeJob.Cancel() select { case <-ctx.Done(): p.opts.Logger.Warn(ctx, "graceful shutdown failed", slog.Error(ctx.Err())) return ctx.Err() - case <-p.activeJob.isDone(): + case <-p.activeJob.Done(): p.opts.Logger.Info(ctx, "gracefully shutdown") return nil } @@ -355,9 +355,9 @@ func (p *Server) closeWithError(err error) error { if p.activeJob != nil { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - failErr := p.activeJob.fail(ctx, &proto.FailedJob{Error: errMsg}) + failErr := p.activeJob.Fail(ctx, &proto.FailedJob{Error: errMsg}) if failErr != nil { - p.activeJob.forceStop() + p.activeJob.ForceStop() } if err == nil { err = failErr diff --git a/provisionerd/runner/runner.go b/provisionerd/runner/runner.go index 11f560a19bf6e..88ea32d1cc6a5 100644 --- a/provisionerd/runner/runner.go +++ b/provisionerd/runner/runner.go @@ -29,17 +29,9 @@ const ( MissingParameterErrorText = "missing parameter" ) -type jobRunner interface { - start() - cancel() - isDone() <-chan any - fail(ctx context.Context, f *proto.FailedJob) error - forceStop() -} - -type runner struct { +type Runner struct { job *proto.AcquiredJob - sender messageSender + sender JobUpdater logger slog.Logger filesystem afero.Fs 
workDirectory string @@ -47,51 +39,52 @@ type runner struct { updateInterval time.Duration forceCancelInterval time.Duration - // mutex controls access to all the following variables. - mutex *sync.Mutex - // used to wait for the failedJob or completedJob to be populated - cond *sync.Cond - // closed when the job ends. - done chan any - failedJob *proto.FailedJob - completedJob *proto.CompletedJob + // closed when the Runner is finished sending any updates/failed/complete. + done chan any // active as long as we are not canceled gracefulContext context.Context cancelFunc context.CancelFunc // active as long as we haven't been force stopped forceStopContext context.Context forceStopFunc context.CancelFunc + + // mutex controls access to all the following variables. + mutex *sync.Mutex + // used to wait for the failedJob or completedJob to be populated + cond *sync.Cond + failedJob *proto.FailedJob + completedJob *proto.CompletedJob // setting this false signals that no more messages about this job should be sent. Usually this means that a // terminal message like FailedJob or CompletedJob has been sent, but if we are force canceled, we may set this // false and not send one. okToSend bool } -type messageSender interface { - updateJob(ctx context.Context, in *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) - failJob(ctx context.Context, in *proto.FailedJob) error - completeJob(ctx context.Context, in *proto.CompletedJob) error +type JobUpdater interface { + UpdateJob(ctx context.Context, in *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) + FailJob(ctx context.Context, in *proto.FailedJob) error + CompleteJob(ctx context.Context, in *proto.CompletedJob) error } -func newRunner( +func NewRunner( job *proto.AcquiredJob, - sender messageSender, + updater JobUpdater, logger slog.Logger, filesystem afero.Fs, workDirectory string, provisioner sdkproto.DRPCProvisionerClient, updateInterval time.Duration, - forceCancelInterval time.Duration) jobRunner { + forceCancelInterval time.Duration) *Runner { m := new(sync.Mutex) - // we need to create our contexts here in case a call to cancel() comes immediately. + // we need to create our contexts here in case a call to Cancel() comes immediately. logCtx := slog.With(context.Background(), slog.F("job_id", job.JobId)) forceStopContext, forceStopFunc := context.WithCancel(logCtx) gracefulContext, cancelFunc := context.WithCancel(forceStopContext) - return &runner{ + return &Runner{ job: job, - sender: sender, + sender: updater, logger: logger, filesystem: filesystem, workDirectory: workDirectory, @@ -109,16 +102,16 @@ func newRunner( } } -func (r *runner) start() { +func (r *Runner) Start() { r.mutex.Lock() defer r.mutex.Unlock() defer r.forceStopFunc() // the idea here is to run two goroutines to work on the job: run and heartbeat, then use the `r.cond` to wait until // the job is either complete or failed. This function then sends the complete or failed message --- the exception - // to this is if something calls fail() on the runner; either something external, like the server getting closed, - // or the heartbeat goroutine timing out after attempting to gracefully cancel. If something calls fail(), then - // the failure is sent on that goroutine on the context passed into fail(), and it marks okToSend false to signal + // to this is if something calls Fail() on the Runner; either something external, like the server getting closed, + // or the heartbeat goroutine timing out after attempting to gracefully cancel. 
If something calls Fail(), then + // the failure is sent on that goroutine on the context passed into Fail(), and it marks okToSend false to signal // us here that this function should not also send a terminal message. go r.run() go r.heartbeat() @@ -131,14 +124,14 @@ func (r *runner) start() { } if r.failedJob != nil { r.logger.Debug(r.forceStopContext, "sending FailedJob") - err := r.sender.failJob(r.forceStopContext, r.failedJob) + err := r.sender.FailJob(r.forceStopContext, r.failedJob) if err != nil { r.logger.Error(r.forceStopContext, "send FailJob", slog.Error(err)) } r.logger.Info(r.forceStopContext, "sent FailedJob") } else { r.logger.Debug(r.forceStopContext, "sending CompletedJob") - err := r.sender.completeJob(r.forceStopContext, r.completedJob) + err := r.sender.CompleteJob(r.forceStopContext, r.completedJob) if err != nil { r.logger.Error(r.forceStopContext, "send CompletedJob", slog.Error(err)) } @@ -148,35 +141,35 @@ func (r *runner) start() { r.okToSend = false } -// cancel initiates a cancel on the job, but allows it to keep running to do so gracefully. Read from isDone() to +// Cancel initiates a Cancel on the job, but allows it to keep running to do so gracefully. Read from Done() to // be notified when the job completes. -func (r *runner) cancel() { +func (r *Runner) Cancel() { r.cancelFunc() } -func (r *runner) isDone() <-chan any { +func (r *Runner) Done() <-chan any { return r.done } -// fail immediately halts updates and, if the job is not complete sends FailJob to the coder server. Running goroutines +// Fail immediately halts updates and, if the job is not complete sends FailJob to the coder server. Running goroutines // are canceled but complete asynchronously (although they are prevented from further updating the job to the coder // server). The provided context sets how long to keep trying to send the FailJob. -func (r *runner) fail(ctx context.Context, f *proto.FailedJob) error { +func (r *Runner) Fail(ctx context.Context, f *proto.FailedJob) error { f.JobId = r.job.JobId r.mutex.Lock() defer r.mutex.Unlock() if !r.okToSend { return nil // already done } - r.cancel() + r.Cancel() if r.failedJob == nil { r.failedJob = f r.cond.Signal() } // here we keep the original failed reason if there already was one, but we hadn't yet sent it. It is likely more // informative of the job failing due to some problem running it, whereas this function is used to externally - // force a fail. - err := r.sender.failJob(ctx, r.failedJob) + // force a Fail. + err := r.sender.FailJob(ctx, r.failedJob) r.okToSend = false r.forceStopFunc() close(r.done) @@ -184,7 +177,7 @@ func (r *runner) fail(ctx context.Context, f *proto.FailedJob) error { } // setComplete is an internal function to set the job to completed. This does not send the completedJob. -func (r *runner) setComplete(c *proto.CompletedJob) { +func (r *Runner) setComplete(c *proto.CompletedJob) { r.mutex.Lock() defer r.mutex.Unlock() if r.completedJob == nil { @@ -194,7 +187,7 @@ func (r *runner) setComplete(c *proto.CompletedJob) { } // setFail is an internal function to set the job to failed. This does not send the failedJob. 
-func (r *runner) setFail(f *proto.FailedJob) { +func (r *Runner) setFail(f *proto.FailedJob) { r.mutex.Lock() defer r.mutex.Unlock() if r.failedJob == nil { @@ -204,8 +197,8 @@ func (r *runner) setFail(f *proto.FailedJob) { } } -// forceStop signals all goroutines to stop and prevents any further API calls back to coder server for this job -func (r *runner) forceStop() { +// ForceStop signals all goroutines to stop and prevents any further API calls back to coder server for this job +func (r *Runner) ForceStop() { r.mutex.Lock() defer r.mutex.Unlock() r.forceStopFunc() @@ -215,21 +208,21 @@ func (r *runner) forceStop() { r.okToSend = false close(r.done) // doesn't matter what we put here, since it won't get sent! Just need something to satisfy the condition in - // start() + // Start() r.failedJob = &proto.FailedJob{} r.cond.Signal() } -func (r *runner) update(ctx context.Context, u *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { +func (r *Runner) update(ctx context.Context, u *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { r.mutex.Lock() defer r.mutex.Unlock() if !r.okToSend { return nil, xerrors.New("update skipped; job complete or failed") } - return r.sender.updateJob(ctx, u) + return r.sender.UpdateJob(ctx, u) } -func (r *runner) run() { +func (r *Runner) run() { // push the fail/succeed write onto the defer stack before the cleanup, so that cleanup happens before this. // Failures during this function should write to the _local_ failedJob variable, then return. var failedJob *proto.FailedJob @@ -369,7 +362,7 @@ func (r *runner) run() { } } -func (r *runner) heartbeat() { +func (r *Runner) heartbeat() { ticker := time.NewTicker(r.updateInterval) defer ticker.Stop() for { @@ -383,7 +376,7 @@ func (r *runner) heartbeat() { JobId: r.job.JobId, }) if err != nil { - err = r.fail(r.forceStopContext, r.failedJobf("send periodic update: %s", err)) + err = r.Fail(r.forceStopContext, r.failedJobf("send periodic update: %s", err)) if err != nil { r.logger.Error(r.forceStopContext, "failed to call FailJob", slog.Error(err)) } @@ -393,18 +386,18 @@ func (r *runner) heartbeat() { continue } r.logger.Info(r.forceStopContext, "attempting graceful cancelation") - r.cancel() + r.Cancel() // Hard-cancel the job after a minute of pending cancelation. timer := time.NewTimer(r.forceCancelInterval) select { case <-timer.C: - r.logger.Warn(r.forceStopContext, "cancel timed out") - err := r.fail(r.forceStopContext, r.failedJobf("cancel timed out")) + r.logger.Warn(r.forceStopContext, "Cancel timed out") + err := r.Fail(r.forceStopContext, r.failedJobf("Cancel timed out")) if err != nil { r.logger.Warn(r.forceStopContext, "failed to call FailJob", slog.Error(err)) } return - case <-r.isDone(): + case <-r.Done(): timer.Stop() return case <-r.forceStopContext.Done(): @@ -414,7 +407,7 @@ func (r *runner) heartbeat() { } } -func (r *runner) logCleanup(ctx context.Context) { +func (r *Runner) logCleanup(ctx context.Context) { _, err := r.update(ctx, &proto.UpdateJobRequest{ JobId: r.job.JobId, Logs: []*proto.Log{{ @@ -434,7 +427,7 @@ func (r *runner) logCleanup(ctx context.Context) { // versions. 
const ReadmeFile = "README.md" -func (r *runner) runReadmeParse() *proto.FailedJob { +func (r *Runner) runReadmeParse() *proto.FailedJob { fi, err := afero.ReadFile(r.filesystem, path.Join(r.workDirectory, ReadmeFile)) if err != nil { _, err := r.update(r.forceStopContext, &proto.UpdateJobRequest{ @@ -469,7 +462,7 @@ func (r *runner) runReadmeParse() *proto.FailedJob { return nil } -func (r *runner) runTemplateImport() (*proto.CompletedJob, *proto.FailedJob) { +func (r *Runner) runTemplateImport() (*proto.CompletedJob, *proto.FailedJob) { // Parse parameters and update the job with the parameter specs _, err := r.update(r.forceStopContext, &proto.UpdateJobRequest{ JobId: r.job.JobId, @@ -502,7 +495,7 @@ func (r *runner) runTemplateImport() (*proto.CompletedJob, *proto.FailedJob) { for _, parameterSchema := range parameterSchemas { _, ok := valueByName[parameterSchema.Name] if !ok { - return nil, r.failedJobf("%s: %s", provisionerd.missingParameterErrorText, parameterSchema.Name) + return nil, r.failedJobf("%s: %s", MissingParameterErrorText, parameterSchema.Name) } } @@ -560,7 +553,7 @@ func (r *runner) runTemplateImport() (*proto.CompletedJob, *proto.FailedJob) { } // Parses parameter schemas from source. -func (r *runner) runTemplateImportParse() ([]*sdkproto.ParameterSchema, error) { +func (r *Runner) runTemplateImportParse() ([]*sdkproto.ParameterSchema, error) { stream, err := r.provisioner.Parse(r.forceStopContext, &sdkproto.Parse_Request{ Directory: r.workDirectory, }) @@ -608,7 +601,7 @@ func (r *runner) runTemplateImportParse() ([]*sdkproto.ParameterSchema, error) { // Performs a dry-run provision when importing a template. // This is used to detect resources that would be provisioned // for a workspace in various states. -func (r *runner) runTemplateImportProvision(values []*sdkproto.ParameterValue, metadata *sdkproto.Provision_Metadata) ([]*sdkproto.Resource, error) { +func (r *Runner) runTemplateImportProvision(values []*sdkproto.ParameterValue, metadata *sdkproto.Provision_Metadata) ([]*sdkproto.Resource, error) { var stage string switch metadata.WorkspaceTransition { case sdkproto.WorkspaceTransition_START: @@ -696,7 +689,7 @@ func (r *runner) runTemplateImportProvision(values []*sdkproto.ParameterValue, m } } -func (r *runner) runTemplateDryRun() ( +func (r *Runner) runTemplateDryRun() ( *proto.CompletedJob, *proto.FailedJob) { // Ensure all metadata fields are set as they are all optional for dry-run. 
metadata := r.job.GetTemplateDryRun().GetMetadata() @@ -745,7 +738,7 @@ func (r *runner) runTemplateDryRun() ( }, nil } -func (r *runner) runWorkspaceBuild() ( +func (r *Runner) runWorkspaceBuild() ( *proto.CompletedJob, *proto.FailedJob) { var stage string switch r.job.GetWorkspaceBuild().Metadata.WorkspaceTransition { @@ -868,7 +861,7 @@ func (r *runner) runWorkspaceBuild() ( } } -func (r *runner) failedJobf(format string, args ...interface{}) *proto.FailedJob { +func (r *Runner) failedJobf(format string, args ...interface{}) *proto.FailedJob { return &proto.FailedJob{ JobId: r.job.JobId, Error: fmt.Sprintf(format, args...), From cb07b64ec28832bebb5cd8c325a26794c2b7e701 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 30 Jun 2022 12:49:44 -0700 Subject: [PATCH 4/5] renames and slight reworking from code review Signed-off-by: Spike Curtis --- provisionerd/provisionerd.go | 5 +- provisionerd/provisionerd_test.go | 3 +- provisionerd/runner/runner.go | 209 +++++++++++++++--------------- 3 files changed, 108 insertions(+), 109 deletions(-) diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index 62d4efc8579e6..222b95445b80a 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -9,12 +9,13 @@ import ( "sync" "time" - "cdr.dev/slog" "github.com/hashicorp/yamux" "github.com/spf13/afero" "go.uber.org/atomic" "golang.org/x/xerrors" + "cdr.dev/slog" + "github.com/coder/coder/provisionerd/proto" "github.com/coder/coder/provisionerd/runner" sdkproto "github.com/coder/coder/provisionersdk/proto" @@ -235,7 +236,7 @@ func (p *Server) acquireJob(ctx context.Context) { } p.activeJob = runner.NewRunner(job, p, p.opts.Logger, p.opts.Filesystem, p.opts.WorkDirectory, provisioner, p.opts.UpdateInterval, p.opts.ForceCancelInterval) - go p.activeJob.Start() + go p.activeJob.Run() } func retryable(err error) bool { diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go index 9b544494a1129..0b3626b8fae58 100644 --- a/provisionerd/provisionerd_test.go +++ b/provisionerd/provisionerd_test.go @@ -38,11 +38,10 @@ func TestMain(m *testing.M) { func closedWithin(c chan struct{}, d time.Duration) func() bool { return func() bool { - pop := time.After(d) select { case <-c: return true - case <-pop: + case <-time.After(d): return false } } diff --git a/provisionerd/runner/runner.go b/provisionerd/runner/runner.go index 88ea32d1cc6a5..f3e582024713a 100644 --- a/provisionerd/runner/runner.go +++ b/provisionerd/runner/runner.go @@ -42,11 +42,11 @@ type Runner struct { // closed when the Runner is finished sending any updates/failed/complete. done chan any // active as long as we are not canceled - gracefulContext context.Context - cancelFunc context.CancelFunc + notCanceled context.Context + cancel context.CancelFunc // active as long as we haven't been force stopped - forceStopContext context.Context - forceStopFunc context.CancelFunc + notStopped context.Context + stop context.CancelFunc // mutex controls access to all the following variables. mutex *sync.Mutex @@ -95,25 +95,28 @@ func NewRunner( cond: sync.NewCond(m), done: make(chan any), okToSend: true, - forceStopContext: forceStopContext, - forceStopFunc: forceStopFunc, - gracefulContext: gracefulContext, - cancelFunc: cancelFunc, + notStopped: forceStopContext, + stop: forceStopFunc, + notCanceled: gracefulContext, + cancel: cancelFunc, } } -func (r *Runner) Start() { +// Run the job. 
+// +// the idea here is to run two goroutines to work on the job: doCleanFinish and heartbeat, then use +// the `r.cond` to wait until the job is either complete or failed. This function then sends the +// complete or failed message --- the exception to this is if something calls Fail() on the Runner; +// either something external, like the server getting closed, or the heartbeat goroutine timing out +// after attempting to gracefully cancel. If something calls Fail(), then the failure is sent on +// that goroutine on the context passed into Fail(), and it marks okToSend false to signal us here +// that this function should not also send a terminal message. +func (r *Runner) Run() { r.mutex.Lock() defer r.mutex.Unlock() - defer r.forceStopFunc() - - // the idea here is to run two goroutines to work on the job: run and heartbeat, then use the `r.cond` to wait until - // the job is either complete or failed. This function then sends the complete or failed message --- the exception - // to this is if something calls Fail() on the Runner; either something external, like the server getting closed, - // or the heartbeat goroutine timing out after attempting to gracefully cancel. If something calls Fail(), then - // the failure is sent on that goroutine on the context passed into Fail(), and it marks okToSend false to signal - // us here that this function should not also send a terminal message. - go r.run() + defer r.stop() + + go r.doCleanFinish() go r.heartbeat() for r.failedJob == nil && r.completedJob == nil { r.cond.Wait() @@ -123,19 +126,19 @@ func (r *Runner) Start() { return } if r.failedJob != nil { - r.logger.Debug(r.forceStopContext, "sending FailedJob") - err := r.sender.FailJob(r.forceStopContext, r.failedJob) + r.logger.Debug(r.notStopped, "sending FailedJob") + err := r.sender.FailJob(r.notStopped, r.failedJob) if err != nil { - r.logger.Error(r.forceStopContext, "send FailJob", slog.Error(err)) + r.logger.Error(r.notStopped, "send FailJob", slog.Error(err)) } - r.logger.Info(r.forceStopContext, "sent FailedJob") + r.logger.Info(r.notStopped, "sent FailedJob") } else { - r.logger.Debug(r.forceStopContext, "sending CompletedJob") - err := r.sender.CompleteJob(r.forceStopContext, r.completedJob) + r.logger.Debug(r.notStopped, "sending CompletedJob") + err := r.sender.CompleteJob(r.notStopped, r.completedJob) if err != nil { - r.logger.Error(r.forceStopContext, "send CompletedJob", slog.Error(err)) + r.logger.Error(r.notStopped, "send CompletedJob", slog.Error(err)) } - r.logger.Info(r.forceStopContext, "sent CompletedJob") + r.logger.Info(r.notStopped, "sent CompletedJob") } close(r.done) r.okToSend = false @@ -144,7 +147,7 @@ func (r *Runner) Start() { // Cancel initiates a Cancel on the job, but allows it to keep running to do so gracefully. Read from Done() to // be notified when the job completes. func (r *Runner) Cancel() { - r.cancelFunc() + r.cancel() } func (r *Runner) Done() <-chan any { @@ -171,7 +174,7 @@ func (r *Runner) Fail(ctx context.Context, f *proto.FailedJob) error { // force a Fail. 
err := r.sender.FailJob(ctx, r.failedJob) r.okToSend = false - r.forceStopFunc() + r.stop() close(r.done) return err } @@ -191,7 +194,7 @@ func (r *Runner) setFail(f *proto.FailedJob) { r.mutex.Lock() defer r.mutex.Unlock() if r.failedJob == nil { - f.JobId = r.job.GetJobId() + f.JobId = r.job.JobId r.failedJob = f r.cond.Signal() } @@ -201,7 +204,7 @@ func (r *Runner) setFail(f *proto.FailedJob) { func (r *Runner) ForceStop() { r.mutex.Lock() defer r.mutex.Unlock() - r.forceStopFunc() + r.stop() if !r.okToSend { return } @@ -222,9 +225,10 @@ func (r *Runner) update(ctx context.Context, u *proto.UpdateJobRequest) (*proto. return r.sender.UpdateJob(ctx, u) } -func (r *Runner) run() { - // push the fail/succeed write onto the defer stack before the cleanup, so that cleanup happens before this. - // Failures during this function should write to the _local_ failedJob variable, then return. +// doCleanFinish wraps a call to do() with cleaning up the job and setting the terminal messages +func (r *Runner) doCleanFinish() { + // push the fail/succeed write onto the defer stack before the cleanup, so that cleanup happens + // before this. var failedJob *proto.FailedJob var completedJob *proto.CompletedJob defer func() { @@ -236,7 +240,19 @@ func (r *Runner) run() { }() defer func() { - r.logCleanup(r.forceStopContext) + _, err := r.update(r.notStopped, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER_DAEMON, + Level: sdkproto.LogLevel_INFO, + Stage: "Cleaning Up", + CreatedAt: time.Now().UTC().UnixMilli(), + }}, + }) + if err != nil { + r.logger.Warn(r.notStopped, "failed to log cleanup") + return + } // Cleanup the work directory after execution. for attempt := 0; attempt < 5; attempt++ { @@ -246,22 +262,26 @@ func (r *Runner) run() { // When the provisioner daemon is shutting down, // it may take a few milliseconds for processes to exit. 
// See: https://github.com/golang/go/issues/50510 - r.logger.Debug(r.forceStopContext, "failed to clean work directory; trying again", slog.Error(err)) + r.logger.Debug(r.notStopped, "failed to clean work directory; trying again", slog.Error(err)) time.Sleep(250 * time.Millisecond) continue } - r.logger.Debug(r.forceStopContext, "cleaned up work directory", slog.Error(err)) + r.logger.Debug(r.notStopped, "cleaned up work directory", slog.Error(err)) break } }() + completedJob, failedJob = r.do() +} + +// do actually does the work of running the job +func (r *Runner) do() (*proto.CompletedJob, *proto.FailedJob) { err := r.filesystem.MkdirAll(r.workDirectory, 0700) if err != nil { - failedJob = r.failedJobf("create work directory %q: %s", r.workDirectory, err) - return + return nil, r.failedJobf("create work directory %q: %s", r.workDirectory, err) } - _, err = r.update(r.forceStopContext, &proto.UpdateJobRequest{ + _, err = r.update(r.notStopped, &proto.UpdateJobRequest{ JobId: r.job.JobId, Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER_DAEMON, @@ -271,11 +291,10 @@ func (r *Runner) run() { }}, }) if err != nil { - failedJob = r.failedJobf("write log: %s", err) - return + return nil, r.failedJobf("write log: %s", err) } - r.logger.Info(r.forceStopContext, "unpacking template source archive", + r.logger.Info(r.notStopped, "unpacking template source archive", slog.F("size_bytes", len(r.job.TemplateSourceArchive))) reader := tar.NewReader(bytes.NewBuffer(r.job.TemplateSourceArchive)) for { @@ -284,14 +303,12 @@ func (r *Runner) run() { break } if err != nil { - failedJob = r.failedJobf("read template source archive: %s", err) - return + return nil, r.failedJobf("read template source archive: %s", err) } // #nosec headerPath := filepath.Join(r.workDirectory, header.Name) if !strings.HasPrefix(headerPath, filepath.Clean(r.workDirectory)) { - failedJob = r.failedJobf("tar attempts to target relative upper directory") - return + return nil, r.failedJobf("tar attempts to target relative upper directory") } mode := header.FileInfo().Mode() if mode == 0 { @@ -301,15 +318,13 @@ func (r *Runner) run() { case tar.TypeDir: err = r.filesystem.MkdirAll(headerPath, mode) if err != nil { - failedJob = r.failedJobf("mkdir %q: %s", headerPath, err) - return + return nil, r.failedJobf("mkdir %q: %s", headerPath, err) } r.logger.Debug(context.Background(), "extracted directory", slog.F("path", headerPath)) case tar.TypeReg: file, err := r.filesystem.OpenFile(headerPath, os.O_CREATE|os.O_RDWR, mode) if err != nil { - failedJob = r.failedJobf("create file %q (mode %s): %s", headerPath, mode, err) - return + return nil, r.failedJobf("create file %q (mode %s): %s", headerPath, mode, err) } // Max file size of 10MiB. 
size, err := io.CopyN(file, reader, 10<<20) @@ -318,13 +333,11 @@ func (r *Runner) run() { } if err != nil { _ = file.Close() - failedJob = r.failedJobf("copy file %q: %s", headerPath, err) - return + return nil, r.failedJobf("copy file %q: %s", headerPath, err) } err = file.Close() if err != nil { - failedJob = r.failedJobf("close file %q: %s", headerPath, err) - return + return nil, r.failedJobf("close file %q: %s", headerPath, err) } r.logger.Debug(context.Background(), "extracted file", slog.F("size_bytes", size), @@ -338,91 +351,77 @@ func (r *Runner) run() { case *proto.AcquiredJob_TemplateImport_: r.logger.Debug(context.Background(), "acquired job is template import") - failedJob = r.runReadmeParse() - if failedJob == nil { - completedJob, failedJob = r.runTemplateImport() + failedJob := r.runReadmeParse() + if failedJob != nil { + return nil, failedJob } + return r.runTemplateImport() case *proto.AcquiredJob_TemplateDryRun_: r.logger.Debug(context.Background(), "acquired job is template dry-run", slog.F("workspace_name", jobType.TemplateDryRun.Metadata.WorkspaceName), slog.F("parameters", jobType.TemplateDryRun.ParameterValues), ) - completedJob, failedJob = r.runTemplateDryRun() + return r.runTemplateDryRun() case *proto.AcquiredJob_WorkspaceBuild_: r.logger.Debug(context.Background(), "acquired job is workspace provision", slog.F("workspace_name", jobType.WorkspaceBuild.WorkspaceName), slog.F("state_length", len(jobType.WorkspaceBuild.State)), slog.F("parameters", jobType.WorkspaceBuild.ParameterValues), ) - - completedJob, failedJob = r.runWorkspaceBuild() + return r.runWorkspaceBuild() default: - failedJob = r.failedJobf("unknown job type %q; ensure your provisioner daemon is up-to-date", + return nil, r.failedJobf("unknown job type %q; ensure your provisioner daemon is up-to-date", reflect.TypeOf(r.job.Type).String()) } } +// heartbeat periodically sends updates on the job, which keeps coder server from assuming the job +// is stalled, and allows the runner to learn if the job has been canceled by the user. func (r *Runner) heartbeat() { ticker := time.NewTicker(r.updateInterval) defer ticker.Stop() for { select { - case <-r.gracefulContext.Done(): + case <-r.notCanceled.Done(): return case <-ticker.C: } - resp, err := r.update(r.forceStopContext, &proto.UpdateJobRequest{ + resp, err := r.update(r.notStopped, &proto.UpdateJobRequest{ JobId: r.job.JobId, }) if err != nil { - err = r.Fail(r.forceStopContext, r.failedJobf("send periodic update: %s", err)) + err = r.Fail(r.notStopped, r.failedJobf("send periodic update: %s", err)) if err != nil { - r.logger.Error(r.forceStopContext, "failed to call FailJob", slog.Error(err)) + r.logger.Error(r.notStopped, "failed to call FailJob", slog.Error(err)) } return } if !resp.Canceled { continue } - r.logger.Info(r.forceStopContext, "attempting graceful cancelation") + r.logger.Info(r.notStopped, "attempting graceful cancelation") r.Cancel() // Hard-cancel the job after a minute of pending cancelation. 
timer := time.NewTimer(r.forceCancelInterval) select { case <-timer.C: - r.logger.Warn(r.forceStopContext, "Cancel timed out") - err := r.Fail(r.forceStopContext, r.failedJobf("Cancel timed out")) + r.logger.Warn(r.notStopped, "Cancel timed out") + err := r.Fail(r.notStopped, r.failedJobf("Cancel timed out")) if err != nil { - r.logger.Warn(r.forceStopContext, "failed to call FailJob", slog.Error(err)) + r.logger.Warn(r.notStopped, "failed to call FailJob", slog.Error(err)) } return case <-r.Done(): timer.Stop() return - case <-r.forceStopContext.Done(): + case <-r.notStopped.Done(): timer.Stop() return } } } -func (r *Runner) logCleanup(ctx context.Context) { - _, err := r.update(ctx, &proto.UpdateJobRequest{ - JobId: r.job.JobId, - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_INFO, - Stage: "Cleaning Up", - CreatedAt: time.Now().UTC().UnixMilli(), - }}, - }) - if err != nil { - r.logger.Warn(ctx, "failed to log cleanup") - return - } -} - // ReadmeFile is the location we look for to extract documentation from template // versions. const ReadmeFile = "README.md" @@ -430,7 +429,7 @@ const ReadmeFile = "README.md" func (r *Runner) runReadmeParse() *proto.FailedJob { fi, err := afero.ReadFile(r.filesystem, path.Join(r.workDirectory, ReadmeFile)) if err != nil { - _, err := r.update(r.forceStopContext, &proto.UpdateJobRequest{ + _, err := r.update(r.notStopped, &proto.UpdateJobRequest{ JobId: r.job.JobId, Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER_DAEMON, @@ -446,7 +445,7 @@ func (r *Runner) runReadmeParse() *proto.FailedJob { return nil } - _, err = r.update(r.forceStopContext, &proto.UpdateJobRequest{ + _, err = r.update(r.notStopped, &proto.UpdateJobRequest{ JobId: r.job.JobId, Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER_DAEMON, @@ -464,7 +463,7 @@ func (r *Runner) runReadmeParse() *proto.FailedJob { func (r *Runner) runTemplateImport() (*proto.CompletedJob, *proto.FailedJob) { // Parse parameters and update the job with the parameter specs - _, err := r.update(r.forceStopContext, &proto.UpdateJobRequest{ + _, err := r.update(r.notStopped, &proto.UpdateJobRequest{ JobId: r.job.JobId, Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER_DAEMON, @@ -480,7 +479,7 @@ func (r *Runner) runTemplateImport() (*proto.CompletedJob, *proto.FailedJob) { if err != nil { return nil, r.failedJobf("run parse: %s", err) } - updateResponse, err := r.update(r.forceStopContext, &proto.UpdateJobRequest{ + updateResponse, err := r.update(r.notStopped, &proto.UpdateJobRequest{ JobId: r.job.JobId, ParameterSchemas: parameterSchemas, }) @@ -500,7 +499,7 @@ func (r *Runner) runTemplateImport() (*proto.CompletedJob, *proto.FailedJob) { } // Determine persistent resources - _, err = r.update(r.forceStopContext, &proto.UpdateJobRequest{ + _, err = r.update(r.notStopped, &proto.UpdateJobRequest{ JobId: r.job.JobId, Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER_DAEMON, @@ -521,7 +520,7 @@ func (r *Runner) runTemplateImport() (*proto.CompletedJob, *proto.FailedJob) { } // Determine ephemeral resources. - _, err = r.update(r.forceStopContext, &proto.UpdateJobRequest{ + _, err = r.update(r.notStopped, &proto.UpdateJobRequest{ JobId: r.job.JobId, Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER_DAEMON, @@ -554,7 +553,7 @@ func (r *Runner) runTemplateImport() (*proto.CompletedJob, *proto.FailedJob) { // Parses parameter schemas from source. 
func (r *Runner) runTemplateImportParse() ([]*sdkproto.ParameterSchema, error) { - stream, err := r.provisioner.Parse(r.forceStopContext, &sdkproto.Parse_Request{ + stream, err := r.provisioner.Parse(r.notStopped, &sdkproto.Parse_Request{ Directory: r.workDirectory, }) if err != nil { @@ -573,7 +572,7 @@ func (r *Runner) runTemplateImportParse() ([]*sdkproto.ParameterSchema, error) { slog.F("output", msgType.Log.Output), ) - _, err = r.update(r.forceStopContext, &proto.UpdateJobRequest{ + _, err = r.update(r.notStopped, &proto.UpdateJobRequest{ JobId: r.job.JobId, Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER, @@ -609,18 +608,18 @@ func (r *Runner) runTemplateImportProvision(values []*sdkproto.ParameterValue, m case sdkproto.WorkspaceTransition_STOP: stage = "Detecting ephemeral resources" } - // use the forceStopContext so that if we attempt to gracefully cancel, the stream will still be available for us + // use the notStopped so that if we attempt to gracefully cancel, the stream will still be available for us // to send the cancel to the provisioner - stream, err := r.provisioner.Provision(r.forceStopContext) + stream, err := r.provisioner.Provision(r.notStopped) if err != nil { return nil, xerrors.Errorf("provision: %w", err) } defer stream.Close() go func() { select { - case <-r.forceStopContext.Done(): + case <-r.notStopped.Done(): return - case <-r.gracefulContext.Done(): + case <-r.notCanceled.Done(): _ = stream.Send(&sdkproto.Provision_Request{ Type: &sdkproto.Provision_Request_Cancel{ Cancel: &sdkproto.Provision_Cancel{}, @@ -653,7 +652,7 @@ func (r *Runner) runTemplateImportProvision(values []*sdkproto.ParameterValue, m slog.F("level", msgType.Log.Level), slog.F("output", msgType.Log.Output), ) - _, err = r.update(r.forceStopContext, &proto.UpdateJobRequest{ + _, err = r.update(r.notStopped, &proto.UpdateJobRequest{ JobId: r.job.JobId, Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER, @@ -750,7 +749,7 @@ func (r *Runner) runWorkspaceBuild() ( stage = "Destroying workspace" } - _, err := r.update(r.forceStopContext, &proto.UpdateJobRequest{ + _, err := r.update(r.notStopped, &proto.UpdateJobRequest{ JobId: r.job.JobId, Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER_DAEMON, @@ -763,18 +762,18 @@ func (r *Runner) runWorkspaceBuild() ( return nil, r.failedJobf("write log: %s", err) } - // use the forceStopContext so that if we attempt to gracefully cancel, the stream will still be available for us + // use the notStopped so that if we attempt to gracefully cancel, the stream will still be available for us // to send the cancel to the provisioner - stream, err := r.provisioner.Provision(r.forceStopContext) + stream, err := r.provisioner.Provision(r.notStopped) if err != nil { return nil, r.failedJobf("provision: %s", err) } defer stream.Close() go func() { select { - case <-r.forceStopContext.Done(): + case <-r.notStopped.Done(): return - case <-r.gracefulContext.Done(): + case <-r.notCanceled.Done(): _ = stream.Send(&sdkproto.Provision_Request{ Type: &sdkproto.Provision_Request_Cancel{ Cancel: &sdkproto.Provision_Cancel{}, @@ -809,7 +808,7 @@ func (r *Runner) runWorkspaceBuild() ( slog.F("workspace_build_id", r.job.GetWorkspaceBuild().WorkspaceBuildId), ) - _, err = r.update(r.forceStopContext, &proto.UpdateJobRequest{ + _, err = r.update(r.notStopped, &proto.UpdateJobRequest{ JobId: r.job.JobId, Logs: []*proto.Log{{ Source: proto.LogSource_PROVISIONER, From 164a3024fe52aa47af23a9605d8b2238a2ad78a6 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 
30 Jun 2022 12:56:54 -0700 Subject: [PATCH 5/5] Reword comment about okToSend Signed-off-by: Spike Curtis --- provisionerd/runner/runner.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/provisionerd/runner/runner.go b/provisionerd/runner/runner.go index f3e582024713a..f26808c00c824 100644 --- a/provisionerd/runner/runner.go +++ b/provisionerd/runner/runner.go @@ -54,9 +54,10 @@ type Runner struct { cond *sync.Cond failedJob *proto.FailedJob completedJob *proto.CompletedJob - // setting this false signals that no more messages about this job should be sent. Usually this means that a - // terminal message like FailedJob or CompletedJob has been sent, but if we are force canceled, we may set this - // false and not send one. + // setting this false signals that no more messages about this job should be sent. Usually this + // means that a terminal message like FailedJob or CompletedJob has been sent, even in the case + // of a Cancel(). However, when someone calls Fail() or ForceStop(), we might not send the + // terminal message, but okToSend is set to false regardless. okToSend bool }
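
For context on how the exported surface introduced in these patches fits together, here is a minimal sketch of a test double that satisfies the runner.JobUpdater interface from PATCH 3. The fakeUpdater type, its field names, and the runner_test package name are illustrative assumptions only; the method set and the proto request/response types mirror what the diff itself shows.

// Package runner_test is a hypothetical test package; fakeUpdater is an
// in-memory JobUpdater that records every update and the terminal message.
package runner_test

import (
	"context"
	"sync"

	"github.com/coder/coder/provisionerd/proto"
)

type fakeUpdater struct {
	mu        sync.Mutex
	updates   []*proto.UpdateJobRequest
	failed    *proto.FailedJob
	completed *proto.CompletedJob
}

func (f *fakeUpdater) UpdateJob(_ context.Context, in *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.updates = append(f.updates, in)
	// Returning the zero-value response (Canceled == false) keeps the
	// heartbeat goroutine from initiating a graceful cancelation.
	return &proto.UpdateJobResponse{}, nil
}

func (f *fakeUpdater) FailJob(_ context.Context, in *proto.FailedJob) error {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.failed = in
	return nil
}

func (f *fakeUpdater) CompleteJob(_ context.Context, in *proto.CompletedJob) error {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.completed = in
	return nil
}

A test built on this sketch could construct the runner with runner.NewRunner(job, &fakeUpdater{}, logger, filesystem, workDirectory, provisionerClient, updateInterval, forceCancelInterval), start it with Run(), and wait on Done() before asserting that exactly one of the recorded failed/completed messages was sent after all updates, which is the ordering guarantee this patch series is about.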