diff --git a/agent/reaper/reaper_stub.go b/agent/reaper/reaper_stub.go index 538a7db71887a..8cd87ab0bf3a7 100644 --- a/agent/reaper/reaper_stub.go +++ b/agent/reaper/reaper_stub.go @@ -7,6 +7,6 @@ func IsInitProcess() bool { return false } -func ForkReap(opt ...Option) error { +func ForkReap(_ ...Option) error { return nil } diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index d06007c50c3e4..db0d30ff8b23a 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -383,6 +383,7 @@ func UpdateTemplateVersion(t *testing.T, client *codersdk.Client, organizationID // AwaitTemplateImportJob awaits for an import job to reach completed status. func AwaitTemplateVersionJob(t *testing.T, client *codersdk.Client, version uuid.UUID) codersdk.TemplateVersion { + t.Logf("waiting for template version job %s", version) var templateVersion codersdk.TemplateVersion require.Eventually(t, func() bool { var err error diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go index 7a013e15d755d..626483994ddd9 100644 --- a/coderd/provisionerdaemons.go +++ b/coderd/provisionerdaemons.go @@ -331,6 +331,7 @@ func (server *provisionerdServer) UpdateJob(ctx context.Context, request *proto. if err != nil { return nil, xerrors.Errorf("parse job id: %w", err) } + server.Logger.Debug(ctx, "UpdateJob starting", slog.F("job_id", parsedID)) job, err := server.Database.GetProvisionerJobByID(ctx, parsedID) if err != nil { return nil, xerrors.Errorf("get job: %w", err) @@ -368,19 +369,27 @@ func (server *provisionerdServer) UpdateJob(ctx context.Context, request *proto. insertParams.Stage = append(insertParams.Stage, log.Stage) insertParams.Source = append(insertParams.Source, logSource) insertParams.Output = append(insertParams.Output, log.Output) + server.Logger.Debug(ctx, "job log", + slog.F("job_id", parsedID), + slog.F("stage", log.Stage), + slog.F("output", log.Output)) } logs, err := server.Database.InsertProvisionerJobLogs(context.Background(), insertParams) if err != nil { + server.Logger.Error(ctx, "failed to insert job logs", slog.F("job_id", parsedID), slog.Error(err)) return nil, xerrors.Errorf("insert job logs: %w", err) } + server.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID)) data, err := json.Marshal(logs) if err != nil { return nil, xerrors.Errorf("marshal job log: %w", err) } err = server.Pubsub.Publish(provisionerJobLogsChannel(parsedID), data) if err != nil { + server.Logger.Error(ctx, "failed to publish job logs", slog.F("job_id", parsedID), slog.Error(err)) return nil, xerrors.Errorf("publish job log: %w", err) } + server.Logger.Debug(ctx, "published job logs", slog.F("job_id", parsedID)) } if len(request.Readme) > 0 { @@ -488,6 +497,7 @@ func (server *provisionerdServer) FailJob(ctx context.Context, failJob *proto.Fa if err != nil { return nil, xerrors.Errorf("parse job id: %w", err) } + server.Logger.Debug(ctx, "FailJob starting", slog.F("job_id", jobID)) job, err := server.Database.GetProvisionerJobByID(ctx, jobID) if err != nil { return nil, xerrors.Errorf("get provisioner job: %w", err) @@ -547,6 +557,7 @@ func (server *provisionerdServer) CompleteJob(ctx context.Context, completed *pr if err != nil { return nil, xerrors.Errorf("parse job id: %w", err) } + server.Logger.Debug(ctx, "CompleteJob starting", slog.F("job_id", jobID)) job, err := server.Database.GetProvisionerJobByID(ctx, jobID) if err != nil { return nil, xerrors.Errorf("get job by id: %w", err) @@ -699,6 +710,7 @@ func (server 
*provisionerdServer) CompleteJob(ctx context.Context, completed *pr reflect.TypeOf(completed.Type).String()) } + server.Logger.Debug(ctx, "CompleteJob done", slog.F("job_id", jobID)) return &proto.Empty{}, nil } diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index 38d1efeae43df..8b163412f0ff4 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -130,6 +130,7 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job for _, log := range logs { select { case bufferedLogs <- log: + api.Logger.Debug(r.Context(), "subscribe buffered log", slog.F("job_id", job.ID), slog.F("stage", log.Stage)) default: // If this overflows users could miss logs streaming. This can happen // if a database request takes a long amount of time, and we get a lot of logs. @@ -176,8 +177,10 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job for { select { case <-r.Context().Done(): + api.Logger.Debug(context.Background(), "job logs context canceled", slog.F("job_id", job.ID)) return case log := <-bufferedLogs: + api.Logger.Debug(r.Context(), "subscribe encoding log", slog.F("job_id", job.ID), slog.F("stage", log.Stage)) err = encoder.Encode(convertProvisionerJobLog(log)) if err != nil { return @@ -189,6 +192,7 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job continue } if job.CompletedAt.Valid { + api.Logger.Debug(context.Background(), "streaming job logs done; job done", slog.F("job_id", job.ID)) return } } diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index ad69bfc2bbfb9..222b95445b80a 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -1,41 +1,32 @@ package provisionerd import ( - "archive/tar" - "bytes" "context" "errors" "fmt" "io" - "os" - "path" - "path/filepath" - "reflect" "strings" "sync" "time" - "github.com/google/uuid" "github.com/hashicorp/yamux" "github.com/spf13/afero" "go.uber.org/atomic" "golang.org/x/xerrors" "cdr.dev/slog" + "github.com/coder/coder/provisionerd/proto" + "github.com/coder/coder/provisionerd/runner" sdkproto "github.com/coder/coder/provisionersdk/proto" "github.com/coder/retry" ) -const ( - missingParameterErrorText = "missing parameter" -) - // IsMissingParameterError returns whether the error message provided // is a missing parameter error. This can indicate to consumers that // they should check parameters. func IsMissingParameterError(err string) bool { - return strings.Contains(err, missingParameterErrorText) + return strings.Contains(err, runner.MissingParameterErrorText) } // Dialer represents the function to create a daemon client connection. @@ -79,13 +70,8 @@ func New(clientDialer Dialer, opts *Options) *Server { closeCancel: ctxCancel, shutdown: make(chan struct{}), - - jobRunning: make(chan struct{}), - jobFailed: *atomic.NewBool(true), } - // Start off with a closed channel so - // isRunningJob() returns properly. - close(daemon.jobRunning) + go daemon.connect(ctx) return daemon } @@ -96,22 +82,13 @@ type Server struct { clientDialer Dialer clientValue atomic.Value - // Locked when closing the daemon. - closeMutex sync.Mutex + // Locked when closing the daemon, shutting down, or starting a new job. + mutex sync.Mutex closeContext context.Context closeCancel context.CancelFunc closeError error - - shutdownMutex sync.Mutex - shutdown chan struct{} - - // Locked when acquiring or failing a job. 
- jobMutex sync.Mutex - jobID string - jobRunningMutex sync.Mutex - jobRunning chan struct{} - jobFailed atomic.Bool - jobCancel context.CancelFunc + shutdown chan struct{} + activeJob *runner.Runner } // Connect establishes a connection to coderd. @@ -192,9 +169,13 @@ func (p *Server) client() (proto.DRPCProvisionerDaemonClient, bool) { return client, ok } +// isRunningJob returns true if a job is running. Caller must hold the mutex. func (p *Server) isRunningJob() bool { + if p.activeJob == nil { + return false + } select { - case <-p.jobRunning: + case <-p.activeJob.Done(): return false default: return true @@ -203,8 +184,8 @@ func (p *Server) isRunningJob() bool { // Locks a job in the database, and runs it! func (p *Server) acquireJob(ctx context.Context) { - p.jobMutex.Lock() - defer p.jobMutex.Unlock() + p.mutex.Lock() + defer p.mutex.Unlock() if p.isClosed() { return } @@ -235,775 +216,78 @@ func (p *Server) acquireJob(ctx context.Context) { if job.JobId == "" { return } - ctx, p.jobCancel = context.WithCancel(ctx) - p.jobRunningMutex.Lock() - p.jobRunning = make(chan struct{}) - p.jobRunningMutex.Unlock() - p.jobFailed.Store(false) - p.jobID = job.JobId - p.opts.Logger.Info(context.Background(), "acquired job", slog.F("initiator_username", job.UserName), slog.F("provisioner", job.Provisioner), - slog.F("id", job.JobId), + slog.F("job_id", job.JobId), ) - go p.runJob(ctx, job) -} - -func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) { - shutdown, shutdownCancel := context.WithCancel(ctx) - defer shutdownCancel() - - complete, completeCancel := context.WithCancel(ctx) - defer completeCancel() - go func() { - ticker := time.NewTicker(p.opts.UpdateInterval) - defer ticker.Stop() - for { - select { - case <-p.closeContext.Done(): - return - case <-ctx.Done(): - return - case <-complete.Done(): - return - case <-p.shutdown: - p.opts.Logger.Info(ctx, "attempting graceful cancelation") - shutdownCancel() - return - case <-ticker.C: - } - client, ok := p.client() - if !ok { - continue - } - resp, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.JobId, - }) - if errors.Is(err, yamux.ErrSessionShutdown) || errors.Is(err, io.EOF) { - continue - } - if err != nil { - p.failActiveJobf("send periodic update: %s", err) - return - } - if !resp.Canceled { - continue - } - p.opts.Logger.Info(ctx, "attempting graceful cancelation") - shutdownCancel() - // Hard-cancel the job after a minute of pending cancelation. - timer := time.NewTimer(p.opts.ForceCancelInterval) - select { - case <-timer.C: - p.failActiveJobf("cancelation timed out") - return - case <-ctx.Done(): - timer.Stop() - return - } - } - }() - defer func() { - // Cleanup the work directory after execution. - for attempt := 0; attempt < 5; attempt++ { - err := p.opts.Filesystem.RemoveAll(p.opts.WorkDirectory) - if err != nil { - // On Windows, open files cannot be removed. - // When the provisioner daemon is shutting down, - // it may take a few milliseconds for processes to exit. - // See: https://github.com/golang/go/issues/50510 - p.opts.Logger.Debug(ctx, "failed to clean work directory; trying again", slog.Error(err)) - time.Sleep(250 * time.Millisecond) - continue - } - p.opts.Logger.Debug(ctx, "cleaned up work directory", slog.Error(err)) - break - } - - close(p.jobRunning) - }() - // It's safe to cast this ProvisionerType. This data is coming directly from coderd. 
- provisioner, hasProvisioner := p.opts.Provisioners[job.Provisioner] - if !hasProvisioner { - p.failActiveJobf("provisioner %q not registered", job.Provisioner) - return - } - - err := p.opts.Filesystem.MkdirAll(p.opts.WorkDirectory, 0700) - if err != nil { - p.failActiveJobf("create work directory %q: %s", p.opts.WorkDirectory, err) - return - } - - client, ok := p.client() - if !ok { - p.failActiveJobf("client disconnected") - return - } - _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.GetJobId(), - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_INFO, - Stage: "Setting up", - CreatedAt: time.Now().UTC().UnixMilli(), - }}, - }) - if err != nil { - p.failActiveJobf("write log: %s", err) - return - } - - p.opts.Logger.Info(ctx, "unpacking template source archive", slog.F("size_bytes", len(job.TemplateSourceArchive))) - reader := tar.NewReader(bytes.NewBuffer(job.TemplateSourceArchive)) - for { - header, err := reader.Next() - if errors.Is(err, io.EOF) { - break - } - if err != nil { - p.failActiveJobf("read template source archive: %s", err) - return - } - // #nosec - headerPath := filepath.Join(p.opts.WorkDirectory, header.Name) - if !strings.HasPrefix(headerPath, filepath.Clean(p.opts.WorkDirectory)) { - p.failActiveJobf("tar attempts to target relative upper directory") - return - } - mode := header.FileInfo().Mode() - if mode == 0 { - mode = 0600 - } - switch header.Typeflag { - case tar.TypeDir: - err = p.opts.Filesystem.MkdirAll(headerPath, mode) - if err != nil { - p.failActiveJobf("mkdir %q: %s", headerPath, err) - return - } - p.opts.Logger.Debug(context.Background(), "extracted directory", slog.F("path", headerPath)) - case tar.TypeReg: - file, err := p.opts.Filesystem.OpenFile(headerPath, os.O_CREATE|os.O_RDWR, mode) - if err != nil { - p.failActiveJobf("create file %q (mode %s): %s", headerPath, mode, err) - return - } - // Max file size of 10MiB. 
- size, err := io.CopyN(file, reader, 10<<20) - if errors.Is(err, io.EOF) { - err = nil - } - if err != nil { - _ = file.Close() - p.failActiveJobf("copy file %q: %s", headerPath, err) - return - } - err = file.Close() - if err != nil { - p.failActiveJobf("close file %q: %s", headerPath, err) - return - } - p.opts.Logger.Debug(context.Background(), "extracted file", - slog.F("size_bytes", size), - slog.F("path", headerPath), - slog.F("mode", mode), - ) - } - } - - switch jobType := job.Type.(type) { - case *proto.AcquiredJob_TemplateImport_: - p.opts.Logger.Debug(context.Background(), "acquired job is template import") - - p.runReadmeParse(ctx, job) - p.runTemplateImport(ctx, shutdown, provisioner, job) - case *proto.AcquiredJob_TemplateDryRun_: - p.opts.Logger.Debug(context.Background(), "acquired job is template dry-run", - slog.F("workspace_name", jobType.TemplateDryRun.Metadata.WorkspaceName), - slog.F("parameters", jobType.TemplateDryRun.ParameterValues), - ) - p.runTemplateDryRun(ctx, shutdown, provisioner, job) - case *proto.AcquiredJob_WorkspaceBuild_: - p.opts.Logger.Debug(context.Background(), "acquired job is workspace provision", - slog.F("workspace_name", jobType.WorkspaceBuild.WorkspaceName), - slog.F("state_length", len(jobType.WorkspaceBuild.State)), - slog.F("parameters", jobType.WorkspaceBuild.ParameterValues), - ) - - p.runWorkspaceBuild(ctx, shutdown, provisioner, job) - default: - p.failActiveJobf("unknown job type %q; ensure your provisioner daemon is up-to-date", reflect.TypeOf(job.Type).String()) - return - } - - client, ok = p.client() - if !ok { - return - } - // Ensure the job is still running to output. - // It's possible the job has failed. - if p.isRunningJob() { - _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.GetJobId(), - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_INFO, - Stage: "Cleaning Up", - CreatedAt: time.Now().UTC().UnixMilli(), - }}, - }) - if err != nil { - p.failActiveJobf("write log: %s", err) - return - } - - p.opts.Logger.Info(context.Background(), "completed job", slog.F("id", job.JobId)) - } -} - -// ReadmeFile is the location we look for to extract documentation from template -// versions. 
-const ReadmeFile = "README.md" - -func (p *Server) runReadmeParse(ctx context.Context, job *proto.AcquiredJob) { - client, ok := p.client() + provisioner, ok := p.opts.Provisioners[job.Provisioner] if !ok { - p.failActiveJobf("client disconnected") - return - } - - fi, err := afero.ReadFile(p.opts.Filesystem, path.Join(p.opts.WorkDirectory, ReadmeFile)) - if err != nil { - _, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.GetJobId(), - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_DEBUG, - Stage: "No README.md provided", - CreatedAt: time.Now().UTC().UnixMilli(), - }}, + err := p.FailJob(ctx, &proto.FailedJob{ + JobId: job.JobId, + Error: fmt.Sprintf("no provisioner %s", job.Provisioner), }) if err != nil { - p.failActiveJobf("write log: %s", err) + p.opts.Logger.Error(context.Background(), "failed to call FailJob", + slog.F("job_id", job.JobId), slog.Error(err)) } - - return - } - - _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.GetJobId(), - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_INFO, - Stage: "Adding README.md...", - CreatedAt: time.Now().UTC().UnixMilli(), - }}, - Readme: fi, - }) - if err != nil { - p.failActiveJobf("write log: %s", err) return } + p.activeJob = runner.NewRunner(job, p, p.opts.Logger, p.opts.Filesystem, p.opts.WorkDirectory, provisioner, + p.opts.UpdateInterval, p.opts.ForceCancelInterval) + go p.activeJob.Run() } -func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob) { - client, ok := p.client() - if !ok { - p.failActiveJobf("client disconnected") - return - } - - // Parse parameters and update the job with the parameter specs - _, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.GetJobId(), - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_INFO, - Stage: "Parsing template parameters", - CreatedAt: time.Now().UTC().UnixMilli(), - }}, - }) - if err != nil { - p.failActiveJobf("write log: %s", err) - return - } - parameterSchemas, err := p.runTemplateImportParse(ctx, provisioner, job) - if err != nil { - p.failActiveJobf("run parse: %s", err) - return - } - updateResponse, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.JobId, - ParameterSchemas: parameterSchemas, - }) - if err != nil { - p.failActiveJobf("update job: %s", err) - return - } - - valueByName := map[string]*sdkproto.ParameterValue{} - for _, parameterValue := range updateResponse.ParameterValues { - valueByName[parameterValue.Name] = parameterValue - } - for _, parameterSchema := range parameterSchemas { - _, ok := valueByName[parameterSchema.Name] - if !ok { - p.failActiveJobf("%s: %s", missingParameterErrorText, parameterSchema.Name) - return - } - } - - // Determine persistent resources - _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.GetJobId(), - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_INFO, - Stage: "Detecting persistent resources", - CreatedAt: time.Now().UTC().UnixMilli(), - }}, - }) - if err != nil { - p.failActiveJobf("write log: %s", err) - return - } - startResources, err := p.runTemplateImportProvision(ctx, shutdown, provisioner, job, updateResponse.ParameterValues, &sdkproto.Provision_Metadata{ - CoderUrl: job.GetTemplateImport().Metadata.CoderUrl, - WorkspaceTransition: sdkproto.WorkspaceTransition_START, - }) - 
if err != nil { - p.failActiveJobf("template import provision for start: %s", err) - return - } - - // Determine ephemeral resources. - _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.GetJobId(), - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_INFO, - Stage: "Detecting ephemeral resources", - CreatedAt: time.Now().UTC().UnixMilli(), - }}, - }) - if err != nil { - p.failActiveJobf("write log: %s", err) - return - } - stopResources, err := p.runTemplateImportProvision(ctx, shutdown, provisioner, job, updateResponse.ParameterValues, &sdkproto.Provision_Metadata{ - CoderUrl: job.GetTemplateImport().Metadata.CoderUrl, - WorkspaceTransition: sdkproto.WorkspaceTransition_STOP, - }) - if err != nil { - p.failActiveJobf("template import provision for stop: %s", err) - return - } - - p.completeJob(&proto.CompletedJob{ - JobId: job.JobId, - Type: &proto.CompletedJob_TemplateImport_{ - TemplateImport: &proto.CompletedJob_TemplateImport{ - StartResources: startResources, - StopResources: stopResources, - }, - }, - }) +func retryable(err error) bool { + return xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) || + // annoyingly, dRPC sometimes returns context.Canceled if the transport was closed, even if the context for + // the RPC *is not canceled*. Retrying is fine if the RPC context is not canceled. + xerrors.Is(err, context.Canceled) } -// Parses parameter schemas from source. -func (p *Server) runTemplateImportParse(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob) ([]*sdkproto.ParameterSchema, error) { - client, ok := p.client() - if !ok { - return nil, xerrors.New("client disconnected") - } - stream, err := provisioner.Parse(ctx, &sdkproto.Parse_Request{ - Directory: p.opts.WorkDirectory, - }) - if err != nil { - return nil, xerrors.Errorf("parse source: %w", err) - } - defer stream.Close() - for { - msg, err := stream.Recv() - if err != nil { - return nil, xerrors.Errorf("recv parse source: %w", err) +// clientDoWithRetries runs the function f with a client, and retries with backoff until either the error returned +// is not retryable() or the context expires. 
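+// On a non-retryable result it returns f's response and error as-is; if the context expires first, it returns ctx.Err().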
+func (p *Server) clientDoWithRetries( + ctx context.Context, f func(context.Context, proto.DRPCProvisionerDaemonClient) (any, error)) ( + any, error) { + for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(ctx); { + client, ok := p.client() + if !ok { + continue } - switch msgType := msg.Type.(type) { - case *sdkproto.Parse_Response_Log: - p.opts.Logger.Debug(context.Background(), "parse job logged", - slog.F("level", msgType.Log.Level), - slog.F("output", msgType.Log.Output), - ) - - _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.JobId, - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER, - Level: msgType.Log.Level, - CreatedAt: time.Now().UTC().UnixMilli(), - Output: msgType.Log.Output, - Stage: "Parse parameters", - }}, - }) - if err != nil { - return nil, xerrors.Errorf("update job: %w", err) - } - case *sdkproto.Parse_Response_Complete: - p.opts.Logger.Info(context.Background(), "parse complete", - slog.F("parameter_schemas", msgType.Complete.ParameterSchemas)) - - return msgType.Complete.ParameterSchemas, nil - default: - return nil, xerrors.Errorf("invalid message type %q received from provisioner", - reflect.TypeOf(msg.Type).String()) + resp, err := f(ctx, client) + if retryable(err) { + continue } + return resp, err } + return nil, ctx.Err() } -// Performs a dry-run provision when importing a template. -// This is used to detect resources that would be provisioned -// for a workspace in various states. -func (p *Server) runTemplateImportProvision(ctx, shutdown context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob, values []*sdkproto.ParameterValue, metadata *sdkproto.Provision_Metadata) ([]*sdkproto.Resource, error) { - var stage string - switch metadata.WorkspaceTransition { - case sdkproto.WorkspaceTransition_START: - stage = "Detecting persistent resources" - case sdkproto.WorkspaceTransition_STOP: - stage = "Detecting ephemeral resources" - } - stream, err := provisioner.Provision(ctx) - if err != nil { - return nil, xerrors.Errorf("provision: %w", err) - } - defer stream.Close() - go func() { - select { - case <-ctx.Done(): - return - case <-shutdown.Done(): - _ = stream.Send(&sdkproto.Provision_Request{ - Type: &sdkproto.Provision_Request_Cancel{ - Cancel: &sdkproto.Provision_Cancel{}, - }, - }) - } - }() - err = stream.Send(&sdkproto.Provision_Request{ - Type: &sdkproto.Provision_Request_Start{ - Start: &sdkproto.Provision_Start{ - Directory: p.opts.WorkDirectory, - ParameterValues: values, - DryRun: true, - Metadata: metadata, - }, - }, +func (p *Server) UpdateJob(ctx context.Context, in *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { + out, err := p.clientDoWithRetries(ctx, func(ctx context.Context, client proto.DRPCProvisionerDaemonClient) (any, error) { + return client.UpdateJob(ctx, in) }) if err != nil { - return nil, xerrors.Errorf("start provision: %w", err) - } - - for { - msg, err := stream.Recv() - if err != nil { - return nil, xerrors.Errorf("recv import provision: %w", err) - } - switch msgType := msg.Type.(type) { - case *sdkproto.Provision_Response_Log: - p.opts.Logger.Debug(context.Background(), "template import provision job logged", - slog.F("level", msgType.Log.Level), - slog.F("output", msgType.Log.Output), - ) - client, ok := p.client() - if !ok { - continue - } - _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.JobId, - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER, - Level: msgType.Log.Level, - CreatedAt: 
time.Now().UTC().UnixMilli(), - Output: msgType.Log.Output, - Stage: stage, - }}, - }) - if err != nil { - return nil, xerrors.Errorf("send job update: %w", err) - } - case *sdkproto.Provision_Response_Complete: - if msgType.Complete.Error != "" { - p.opts.Logger.Info(context.Background(), "dry-run provision failure", - slog.F("error", msgType.Complete.Error), - ) - - return nil, xerrors.New(msgType.Complete.Error) - } - - p.opts.Logger.Info(context.Background(), "parse dry-run provision successful", - slog.F("resource_count", len(msgType.Complete.Resources)), - slog.F("resources", msgType.Complete.Resources), - slog.F("state_length", len(msgType.Complete.State)), - ) - - return msgType.Complete.Resources, nil - default: - return nil, xerrors.Errorf("invalid message type %q received from provisioner", - reflect.TypeOf(msg.Type).String()) - } + return nil, err } + // nolint: forcetypeassert + return out.(*proto.UpdateJobResponse), nil } -func (p *Server) runTemplateDryRun(ctx, shutdown context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob) { - // Ensure all metadata fields are set as they are all optional for dry-run. - metadata := job.GetTemplateDryRun().GetMetadata() - metadata.WorkspaceTransition = sdkproto.WorkspaceTransition_START - if metadata.CoderUrl == "" { - metadata.CoderUrl = "http://localhost:3000" - } - if metadata.WorkspaceName == "" { - metadata.WorkspaceName = "dryrun" - } - metadata.WorkspaceOwner = job.UserName - if metadata.WorkspaceOwner == "" { - metadata.WorkspaceOwner = "dryrunner" - } - if metadata.WorkspaceId == "" { - id, err := uuid.NewRandom() - if err != nil { - p.failActiveJobf("generate random ID: %s", err) - return - } - metadata.WorkspaceId = id.String() - } - if metadata.WorkspaceOwnerId == "" { - id, err := uuid.NewRandom() - if err != nil { - p.failActiveJobf("generate random ID: %s", err) - return - } - metadata.WorkspaceOwnerId = id.String() - } - - // Run the template import provision task since it's already a dry run. 
- resources, err := p.runTemplateImportProvision(ctx, - shutdown, - provisioner, - job, - job.GetTemplateDryRun().GetParameterValues(), - metadata, - ) - if err != nil { - p.failActiveJobf("run dry-run provision job: %s", err) - return - } - - p.completeJob(&proto.CompletedJob{ - JobId: job.JobId, - Type: &proto.CompletedJob_TemplateDryRun_{ - TemplateDryRun: &proto.CompletedJob_TemplateDryRun{ - Resources: resources, - }, - }, +func (p *Server) FailJob(ctx context.Context, in *proto.FailedJob) error { + _, err := p.clientDoWithRetries(ctx, func(ctx context.Context, client proto.DRPCProvisionerDaemonClient) (any, error) { + return client.FailJob(ctx, in) }) + return err } -func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob) { - var stage string - switch job.GetWorkspaceBuild().Metadata.WorkspaceTransition { - case sdkproto.WorkspaceTransition_START: - stage = "Starting workspace" - case sdkproto.WorkspaceTransition_STOP: - stage = "Stopping workspace" - case sdkproto.WorkspaceTransition_DESTROY: - stage = "Destroying workspace" - } - - client, ok := p.client() - if !ok { - p.failActiveJobf("client disconnected") - return - } - _, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.GetJobId(), - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER_DAEMON, - Level: sdkproto.LogLevel_INFO, - Stage: stage, - CreatedAt: time.Now().UTC().UnixMilli(), - }}, - }) - if err != nil { - p.failActiveJobf("write log: %s", err) - return - } - - stream, err := provisioner.Provision(ctx) - if err != nil { - p.failActiveJobf("provision: %s", err) - return - } - defer stream.Close() - go func() { - select { - case <-ctx.Done(): - return - case <-shutdown.Done(): - _ = stream.Send(&sdkproto.Provision_Request{ - Type: &sdkproto.Provision_Request_Cancel{ - Cancel: &sdkproto.Provision_Cancel{}, - }, - }) - } - }() - err = stream.Send(&sdkproto.Provision_Request{ - Type: &sdkproto.Provision_Request_Start{ - Start: &sdkproto.Provision_Start{ - Directory: p.opts.WorkDirectory, - ParameterValues: job.GetWorkspaceBuild().ParameterValues, - Metadata: job.GetWorkspaceBuild().Metadata, - State: job.GetWorkspaceBuild().State, - }, - }, - }) - if err != nil { - p.failActiveJobf("start provision: %s", err) - return - } - - for { - msg, err := stream.Recv() - if err != nil { - p.failActiveJobf("recv workspace provision: %s", err) - return - } - switch msgType := msg.Type.(type) { - case *sdkproto.Provision_Response_Log: - p.opts.Logger.Debug(context.Background(), "workspace provision job logged", - slog.F("level", msgType.Log.Level), - slog.F("output", msgType.Log.Output), - slog.F("workspace_build_id", job.GetWorkspaceBuild().WorkspaceBuildId), - ) - - _, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{ - JobId: job.JobId, - Logs: []*proto.Log{{ - Source: proto.LogSource_PROVISIONER, - Level: msgType.Log.Level, - CreatedAt: time.Now().UTC().UnixMilli(), - Output: msgType.Log.Output, - Stage: stage, - }}, - }) - if err != nil { - p.failActiveJobf("send job update: %s", err) - return - } - case *sdkproto.Provision_Response_Complete: - if msgType.Complete.Error != "" { - p.opts.Logger.Info(context.Background(), "provision failed; updating state", - slog.F("state_length", len(msgType.Complete.State)), - ) - - p.failActiveJob(&proto.FailedJob{ - Error: msgType.Complete.Error, - Type: &proto.FailedJob_WorkspaceBuild_{ - WorkspaceBuild: &proto.FailedJob_WorkspaceBuild{ - State: msgType.Complete.State, - }, - }, - }) - return 
- } - - p.completeJob(&proto.CompletedJob{ - JobId: job.JobId, - Type: &proto.CompletedJob_WorkspaceBuild_{ - WorkspaceBuild: &proto.CompletedJob_WorkspaceBuild{ - State: msgType.Complete.State, - Resources: msgType.Complete.Resources, - }, - }, - }) - p.opts.Logger.Info(context.Background(), "provision successful; marked job as complete", - slog.F("resource_count", len(msgType.Complete.Resources)), - slog.F("resources", msgType.Complete.Resources), - slog.F("state_length", len(msgType.Complete.State)), - ) - // Stop looping! - return - default: - p.failActiveJobf("invalid message type %T received from provisioner", msg.Type) - return - } - } -} - -func (p *Server) completeJob(job *proto.CompletedJob) { - for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(p.closeContext); { - client, ok := p.client() - if !ok { - continue - } - // Complete job may need to be async if we disconnected... - // When we reconnect we can flush any of these cached values. - _, err := client.CompleteJob(p.closeContext, job) - if xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) { - continue - } - if err != nil { - p.opts.Logger.Warn(p.closeContext, "failed to complete job", slog.Error(err)) - p.failActiveJobf(err.Error()) - return - } - break - } -} - -func (p *Server) failActiveJobf(format string, args ...interface{}) { - p.failActiveJob(&proto.FailedJob{ - Error: fmt.Sprintf(format, args...), +func (p *Server) CompleteJob(ctx context.Context, in *proto.CompletedJob) error { + _, err := p.clientDoWithRetries(ctx, func(ctx context.Context, client proto.DRPCProvisionerDaemonClient) (any, error) { + return client.CompleteJob(ctx, in) }) -} - -func (p *Server) failActiveJob(failedJob *proto.FailedJob) { - p.jobMutex.Lock() - defer p.jobMutex.Unlock() - if !p.isRunningJob() { - return - } - if p.jobFailed.Load() { - p.opts.Logger.Debug(context.Background(), "job has already been marked as failed", slog.F("error_messsage", failedJob.Error)) - return - } - p.jobFailed.Store(true) - p.jobCancel() - p.opts.Logger.Info(context.Background(), "failing running job", - slog.F("error_message", failedJob.Error), - slog.F("job_id", p.jobID), - ) - failedJob.JobId = p.jobID - for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(p.closeContext); { - client, ok := p.client() - if !ok { - continue - } - _, err := client.FailJob(p.closeContext, failedJob) - if xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) { - continue - } - if err != nil { - if p.isClosed() { - return - } - p.opts.Logger.Warn(context.Background(), "failed to notify of error; job is no longer running", slog.Error(err)) - return - } - p.opts.Logger.Debug(context.Background(), "marked running job as failed") - return - } + return err } // isClosed returns whether the API is closed or not. @@ -1029,18 +313,23 @@ func (p *Server) isShutdown() bool { // Shutdown triggers a graceful exit of each registered provisioner. // It exits when an active job stops. 
func (p *Server) Shutdown(ctx context.Context) error { - p.shutdownMutex.Lock() - defer p.shutdownMutex.Unlock() + p.mutex.Lock() + defer p.mutex.Unlock() if !p.isRunningJob() { return nil } p.opts.Logger.Info(ctx, "attempting graceful shutdown") close(p.shutdown) + if p.activeJob == nil { + return nil + } + // wait for active job + p.activeJob.Cancel() select { case <-ctx.Done(): p.opts.Logger.Warn(ctx, "graceful shutdown failed", slog.Error(ctx.Err())) return ctx.Err() - case <-p.jobRunning: + case <-p.activeJob.Done(): p.opts.Logger.Info(ctx, "gracefully shutdown") return nil } @@ -1053,8 +342,8 @@ func (p *Server) Close() error { // closeWithError closes the provisioner; subsequent reads/writes will return the error err. func (p *Server) closeWithError(err error) error { - p.closeMutex.Lock() - defer p.closeMutex.Unlock() + p.mutex.Lock() + defer p.mutex.Unlock() if p.isClosed() { return p.closeError } @@ -1064,10 +353,18 @@ func (p *Server) closeWithError(err error) error { if err != nil { errMsg = err.Error() } - p.failActiveJobf(errMsg) - p.jobRunningMutex.Lock() - <-p.jobRunning - p.jobRunningMutex.Unlock() + if p.activeJob != nil { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + failErr := p.activeJob.Fail(ctx, &proto.FailedJob{Error: errMsg}) + if failErr != nil { + p.activeJob.ForceStop() + } + if err == nil { + err = failErr + } + } + p.closeCancel() p.opts.Logger.Debug(context.Background(), "closing server with error", slog.Error(err)) diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go index 0031d662311f4..0b3626b8fae58 100644 --- a/provisionerd/provisionerd_test.go +++ b/provisionerd/provisionerd_test.go @@ -4,6 +4,7 @@ import ( "archive/tar" "bytes" "context" + "fmt" "io" "os" "path/filepath" @@ -11,7 +12,10 @@ import ( "testing" "time" + "github.com/coder/coder/provisionerd/runner" + "github.com/hashicorp/yamux" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" "go.uber.org/goleak" @@ -32,6 +36,17 @@ func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } +func closedWithin(c chan struct{}, d time.Duration) func() bool { + return func() bool { + select { + case <-c: + return true + case <-time.After(d): + return false + } + } +} + func TestProvisionerd(t *testing.T) { t.Parallel() @@ -54,7 +69,7 @@ func TestProvisionerd(t *testing.T) { defer close(completeChan) return nil, xerrors.New("an error") }, provisionerd.Provisioners{}) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.NoError(t, closer.Close()) }) @@ -77,7 +92,7 @@ func TestProvisionerd(t *testing.T) { updateJob: noopUpdateJob, }), nil }, provisionerd.Provisioners{}) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.NoError(t, closer.Close()) }) @@ -122,7 +137,7 @@ func TestProvisionerd(t *testing.T) { }), }) closerMutex.Unlock() - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.NoError(t, closer.Close()) }) @@ -160,7 +175,7 @@ func TestProvisionerd(t *testing.T) { }, provisionerd.Provisioners{ "someprovisioner": createProvisionerClient(t, provisionerTestServer{}), }) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.NoError(t, closer.Close()) }) @@ -203,7 +218,7 @@ func TestProvisionerd(t *testing.T) { }, }), }) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.NoError(t, closer.Close()) }) 
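The Shutdown/Close pair above gives callers a two-phase stop: Shutdown asks the active job to cancel gracefully and waits on its Done() channel, while Close (via closeWithError) fails whatever is still running and tears the daemon down. A minimal caller sketch, assuming only the exported methods shown in this diff (the 30-second budget and the logging are illustrative):

package example

import (
	"context"
	"log"
	"time"
)

// stopper matches the subset of provisionerd.Server methods used here.
type stopper interface {
	Shutdown(ctx context.Context) error
	Close() error
}

// stopDaemon attempts a graceful shutdown, then hard-closes regardless.
func stopDaemon(daemon stopper) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	if err := daemon.Shutdown(ctx); err != nil {
		log.Printf("graceful shutdown failed: %v", err)
	}
	// Close fails any job that survived the graceful attempt and releases resources.
	if err := daemon.Close(); err != nil {
		log.Printf("close: %v", err)
	}
}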
@@ -231,8 +246,8 @@ func TestProvisionerd(t *testing.T) { JobId: "test", Provisioner: "someprovisioner", TemplateSourceArchive: createTar(t, map[string]string{ - "test.txt": "content", - provisionerd.ReadmeFile: "# A cool template 😎\n", + "test.txt": "content", + runner.ReadmeFile: "# A cool template 😎\n", }), Type: &proto.AcquiredJob_TemplateImport_{ TemplateImport: &proto.AcquiredJob_TemplateImport{ @@ -308,7 +323,7 @@ func TestProvisionerd(t *testing.T) { }, }), }) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.True(t, didLog.Load()) require.True(t, didComplete.Load()) require.True(t, didDryRun.Load()) @@ -388,7 +403,7 @@ func TestProvisionerd(t *testing.T) { }), }) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.True(t, didLog.Load()) require.True(t, didComplete.Load()) require.NoError(t, closer.Close()) @@ -459,7 +474,7 @@ func TestProvisionerd(t *testing.T) { }, }), }) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.True(t, didLog.Load()) require.True(t, didComplete.Load()) require.NoError(t, closer.Close()) @@ -514,7 +529,7 @@ func TestProvisionerd(t *testing.T) { }, }), }) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.True(t, didFail.Load()) require.NoError(t, closer.Close()) }) @@ -587,10 +602,10 @@ func TestProvisionerd(t *testing.T) { }, }), }) - <-updateChan + require.Condition(t, closedWithin(updateChan, 5*time.Second)) err := server.Shutdown(context.Background()) require.NoError(t, err) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.NoError(t, server.Close()) }) @@ -617,15 +632,21 @@ func TestProvisionerd(t *testing.T) { }, nil }, updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { + resp := &proto.UpdateJobResponse{} if len(update.Logs) > 0 && update.Logs[0].Source == proto.LogSource_PROVISIONER { // Close on a log so we know when the job is in progress! updated.Do(func() { close(updateChan) }) } - return &proto.UpdateJobResponse{ - Canceled: true, - }, nil + // start returning Canceled once we've gotten at least one log. 
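+				// updateChan is closed (never sent on), so once it fires this receive always succeeds without blocking.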
+ select { + case <-updateChan: + resp.Canceled = true + default: + // pass + } + return resp, nil }, failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { completed.Do(func() { @@ -664,8 +685,8 @@ func TestProvisionerd(t *testing.T) { }, }), }) - <-updateChan - <-completeChan + require.Condition(t, closedWithin(updateChan, 5*time.Second)) + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.NoError(t, server.Close()) }) @@ -703,6 +724,7 @@ func TestProvisionerd(t *testing.T) { return &proto.UpdateJobResponse{}, nil }, failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { + assert.Equal(t, job.JobId, "test") if second.Load() { completeOnce.Do(func() { close(completeChan) }) return &proto.Empty{}, nil @@ -736,7 +758,7 @@ func TestProvisionerd(t *testing.T) { }, }), }) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) require.NoError(t, server.Close()) }) @@ -808,7 +830,96 @@ func TestProvisionerd(t *testing.T) { }, }), }) - <-completeChan + require.Condition(t, closedWithin(completeChan, 5*time.Second)) + require.NoError(t, server.Close()) + }) + + t.Run("UpdatesBeforeComplete", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, nil) + m := sync.Mutex{} + var ops []string + completeChan := make(chan struct{}) + completeOnce := sync.Once{} + + server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + m.Lock() + defer m.Unlock() + logger.Info(ctx, "AcquiredJob called.") + if len(ops) > 0 { + return &proto.AcquiredJob{}, nil + } + ops = append(ops, "AcquireJob") + + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + TemplateSourceArchive: createTar(t, map[string]string{ + "test.txt": "content", + }), + Type: &proto.AcquiredJob_WorkspaceBuild_{ + WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ + Metadata: &sdkproto.Provision_Metadata{}, + }, + }, + }, nil + }, + updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { + m.Lock() + defer m.Unlock() + logger.Info(ctx, "UpdateJob called.") + ops = append(ops, "UpdateJob") + for _, log := range update.Logs { + ops = append(ops, fmt.Sprintf("Log: %s | %s", log.Stage, log.Output)) + } + return &proto.UpdateJobResponse{}, nil + }, + completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { + m.Lock() + defer m.Unlock() + logger.Info(ctx, "CompleteJob called.") + ops = append(ops, "CompleteJob") + completeOnce.Do(func() { close(completeChan) }) + return &proto.Empty{}, nil + }, + failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { + m.Lock() + defer m.Unlock() + logger.Info(ctx, "FailJob called.") + ops = append(ops, "FailJob") + return &proto.Empty{}, nil + }, + }), nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error { + err := stream.Send(&sdkproto.Provision_Response{ + Type: &sdkproto.Provision_Response_Log{ + Log: &sdkproto.Log{ + Level: sdkproto.LogLevel_DEBUG, + Output: "wow", + }, + }, + }) + require.NoError(t, err) + + err = stream.Send(&sdkproto.Provision_Response{ + Type: &sdkproto.Provision_Response_Complete{ + Complete: &sdkproto.Provision_Complete{}, + }, + 
}) + require.NoError(t, err) + return nil + }, + }), + }) + require.Condition(t, closedWithin(completeChan, 5*time.Second)) + m.Lock() + defer m.Unlock() + require.Equal(t, ops[len(ops)-1], "CompleteJob") + require.Contains(t, ops[0:len(ops)-1], "Log: Cleaning Up | ") require.NoError(t, server.Close()) }) } diff --git a/provisionerd/runner/runner.go b/provisionerd/runner/runner.go new file mode 100644 index 0000000000000..f26808c00c824 --- /dev/null +++ b/provisionerd/runner/runner.go @@ -0,0 +1,869 @@ +package runner + +import ( + "archive/tar" + "bytes" + "context" + "errors" + "fmt" + "io" + "os" + "path" + "path/filepath" + "reflect" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/spf13/afero" + "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/coder/coder/provisionerd/proto" + sdkproto "github.com/coder/coder/provisionersdk/proto" +) + +const ( + MissingParameterErrorText = "missing parameter" +) + +type Runner struct { + job *proto.AcquiredJob + sender JobUpdater + logger slog.Logger + filesystem afero.Fs + workDirectory string + provisioner sdkproto.DRPCProvisionerClient + updateInterval time.Duration + forceCancelInterval time.Duration + + // closed when the Runner is finished sending any updates/failed/complete. + done chan any + // active as long as we are not canceled + notCanceled context.Context + cancel context.CancelFunc + // active as long as we haven't been force stopped + notStopped context.Context + stop context.CancelFunc + + // mutex controls access to all the following variables. + mutex *sync.Mutex + // used to wait for the failedJob or completedJob to be populated + cond *sync.Cond + failedJob *proto.FailedJob + completedJob *proto.CompletedJob + // setting this false signals that no more messages about this job should be sent. Usually this + // means that a terminal message like FailedJob or CompletedJob has been sent, even in the case + // of a Cancel(). However, when someone calls Fail() or ForceStop(), we might not send the + // terminal message, but okToSend is set to false regardless. + okToSend bool +} + +type JobUpdater interface { + UpdateJob(ctx context.Context, in *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) + FailJob(ctx context.Context, in *proto.FailedJob) error + CompleteJob(ctx context.Context, in *proto.CompletedJob) error +} + +func NewRunner( + job *proto.AcquiredJob, + updater JobUpdater, + logger slog.Logger, + filesystem afero.Fs, + workDirectory string, + provisioner sdkproto.DRPCProvisionerClient, + updateInterval time.Duration, + forceCancelInterval time.Duration) *Runner { + m := new(sync.Mutex) + + // we need to create our contexts here in case a call to Cancel() comes immediately. + logCtx := slog.With(context.Background(), slog.F("job_id", job.JobId)) + forceStopContext, forceStopFunc := context.WithCancel(logCtx) + gracefulContext, cancelFunc := context.WithCancel(forceStopContext) + + return &Runner{ + job: job, + sender: updater, + logger: logger, + filesystem: filesystem, + workDirectory: workDirectory, + provisioner: provisioner, + updateInterval: updateInterval, + forceCancelInterval: forceCancelInterval, + mutex: m, + cond: sync.NewCond(m), + done: make(chan any), + okToSend: true, + notStopped: forceStopContext, + stop: forceStopFunc, + notCanceled: gracefulContext, + cancel: cancelFunc, + } +} + +// Run the job. +// +// the idea here is to run two goroutines to work on the job: doCleanFinish and heartbeat, then use +// the `r.cond` to wait until the job is either complete or failed. 
This function then sends the +// complete or failed message --- the exception to this is if something calls Fail() on the Runner; +// either something external, like the server getting closed, or the heartbeat goroutine timing out +// after attempting to gracefully cancel. If something calls Fail(), then the failure is sent on +// that goroutine on the context passed into Fail(), and it marks okToSend false to signal us here +// that this function should not also send a terminal message. +func (r *Runner) Run() { + r.mutex.Lock() + defer r.mutex.Unlock() + defer r.stop() + + go r.doCleanFinish() + go r.heartbeat() + for r.failedJob == nil && r.completedJob == nil { + r.cond.Wait() + } + if !r.okToSend { + // nothing else to do. + return + } + if r.failedJob != nil { + r.logger.Debug(r.notStopped, "sending FailedJob") + err := r.sender.FailJob(r.notStopped, r.failedJob) + if err != nil { + r.logger.Error(r.notStopped, "send FailJob", slog.Error(err)) + } + r.logger.Info(r.notStopped, "sent FailedJob") + } else { + r.logger.Debug(r.notStopped, "sending CompletedJob") + err := r.sender.CompleteJob(r.notStopped, r.completedJob) + if err != nil { + r.logger.Error(r.notStopped, "send CompletedJob", slog.Error(err)) + } + r.logger.Info(r.notStopped, "sent CompletedJob") + } + close(r.done) + r.okToSend = false +} + +// Cancel initiates a Cancel on the job, but allows it to keep running to do so gracefully. Read from Done() to +// be notified when the job completes. +func (r *Runner) Cancel() { + r.cancel() +} + +func (r *Runner) Done() <-chan any { + return r.done +} + +// Fail immediately halts updates and, if the job is not complete sends FailJob to the coder server. Running goroutines +// are canceled but complete asynchronously (although they are prevented from further updating the job to the coder +// server). The provided context sets how long to keep trying to send the FailJob. +func (r *Runner) Fail(ctx context.Context, f *proto.FailedJob) error { + f.JobId = r.job.JobId + r.mutex.Lock() + defer r.mutex.Unlock() + if !r.okToSend { + return nil // already done + } + r.Cancel() + if r.failedJob == nil { + r.failedJob = f + r.cond.Signal() + } + // here we keep the original failed reason if there already was one, but we hadn't yet sent it. It is likely more + // informative of the job failing due to some problem running it, whereas this function is used to externally + // force a Fail. + err := r.sender.FailJob(ctx, r.failedJob) + r.okToSend = false + r.stop() + close(r.done) + return err +} + +// setComplete is an internal function to set the job to completed. This does not send the completedJob. +func (r *Runner) setComplete(c *proto.CompletedJob) { + r.mutex.Lock() + defer r.mutex.Unlock() + if r.completedJob == nil { + r.completedJob = c + r.cond.Signal() + } +} + +// setFail is an internal function to set the job to failed. This does not send the failedJob. +func (r *Runner) setFail(f *proto.FailedJob) { + r.mutex.Lock() + defer r.mutex.Unlock() + if r.failedJob == nil { + f.JobId = r.job.JobId + r.failedJob = f + r.cond.Signal() + } +} + +// ForceStop signals all goroutines to stop and prevents any further API calls back to coder server for this job +func (r *Runner) ForceStop() { + r.mutex.Lock() + defer r.mutex.Unlock() + r.stop() + if !r.okToSend { + return + } + r.okToSend = false + close(r.done) + // doesn't matter what we put here, since it won't get sent! 
Just need something to satisfy the condition in + // Start() + r.failedJob = &proto.FailedJob{} + r.cond.Signal() +} + +func (r *Runner) update(ctx context.Context, u *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { + r.mutex.Lock() + defer r.mutex.Unlock() + if !r.okToSend { + return nil, xerrors.New("update skipped; job complete or failed") + } + return r.sender.UpdateJob(ctx, u) +} + +// doCleanFinish wraps a call to do() with cleaning up the job and setting the terminal messages +func (r *Runner) doCleanFinish() { + // push the fail/succeed write onto the defer stack before the cleanup, so that cleanup happens + // before this. + var failedJob *proto.FailedJob + var completedJob *proto.CompletedJob + defer func() { + if failedJob != nil { + r.setFail(failedJob) + return + } + r.setComplete(completedJob) + }() + + defer func() { + _, err := r.update(r.notStopped, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER_DAEMON, + Level: sdkproto.LogLevel_INFO, + Stage: "Cleaning Up", + CreatedAt: time.Now().UTC().UnixMilli(), + }}, + }) + if err != nil { + r.logger.Warn(r.notStopped, "failed to log cleanup") + return + } + + // Cleanup the work directory after execution. + for attempt := 0; attempt < 5; attempt++ { + err := r.filesystem.RemoveAll(r.workDirectory) + if err != nil { + // On Windows, open files cannot be removed. + // When the provisioner daemon is shutting down, + // it may take a few milliseconds for processes to exit. + // See: https://github.com/golang/go/issues/50510 + r.logger.Debug(r.notStopped, "failed to clean work directory; trying again", slog.Error(err)) + time.Sleep(250 * time.Millisecond) + continue + } + r.logger.Debug(r.notStopped, "cleaned up work directory", slog.Error(err)) + break + } + }() + + completedJob, failedJob = r.do() +} + +// do actually does the work of running the job +func (r *Runner) do() (*proto.CompletedJob, *proto.FailedJob) { + err := r.filesystem.MkdirAll(r.workDirectory, 0700) + if err != nil { + return nil, r.failedJobf("create work directory %q: %s", r.workDirectory, err) + } + + _, err = r.update(r.notStopped, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER_DAEMON, + Level: sdkproto.LogLevel_INFO, + Stage: "Setting up", + CreatedAt: time.Now().UTC().UnixMilli(), + }}, + }) + if err != nil { + return nil, r.failedJobf("write log: %s", err) + } + + r.logger.Info(r.notStopped, "unpacking template source archive", + slog.F("size_bytes", len(r.job.TemplateSourceArchive))) + reader := tar.NewReader(bytes.NewBuffer(r.job.TemplateSourceArchive)) + for { + header, err := reader.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, r.failedJobf("read template source archive: %s", err) + } + // #nosec + headerPath := filepath.Join(r.workDirectory, header.Name) + if !strings.HasPrefix(headerPath, filepath.Clean(r.workDirectory)) { + return nil, r.failedJobf("tar attempts to target relative upper directory") + } + mode := header.FileInfo().Mode() + if mode == 0 { + mode = 0600 + } + switch header.Typeflag { + case tar.TypeDir: + err = r.filesystem.MkdirAll(headerPath, mode) + if err != nil { + return nil, r.failedJobf("mkdir %q: %s", headerPath, err) + } + r.logger.Debug(context.Background(), "extracted directory", slog.F("path", headerPath)) + case tar.TypeReg: + file, err := r.filesystem.OpenFile(headerPath, os.O_CREATE|os.O_RDWR, mode) + if err != nil { + return nil, r.failedJobf("create file 
%q (mode %s): %s", headerPath, mode, err) + } + // Max file size of 10MiB. + size, err := io.CopyN(file, reader, 10<<20) + if errors.Is(err, io.EOF) { + err = nil + } + if err != nil { + _ = file.Close() + return nil, r.failedJobf("copy file %q: %s", headerPath, err) + } + err = file.Close() + if err != nil { + return nil, r.failedJobf("close file %q: %s", headerPath, err) + } + r.logger.Debug(context.Background(), "extracted file", + slog.F("size_bytes", size), + slog.F("path", headerPath), + slog.F("mode", mode), + ) + } + } + + switch jobType := r.job.Type.(type) { + case *proto.AcquiredJob_TemplateImport_: + r.logger.Debug(context.Background(), "acquired job is template import") + + failedJob := r.runReadmeParse() + if failedJob != nil { + return nil, failedJob + } + return r.runTemplateImport() + case *proto.AcquiredJob_TemplateDryRun_: + r.logger.Debug(context.Background(), "acquired job is template dry-run", + slog.F("workspace_name", jobType.TemplateDryRun.Metadata.WorkspaceName), + slog.F("parameters", jobType.TemplateDryRun.ParameterValues), + ) + return r.runTemplateDryRun() + case *proto.AcquiredJob_WorkspaceBuild_: + r.logger.Debug(context.Background(), "acquired job is workspace provision", + slog.F("workspace_name", jobType.WorkspaceBuild.WorkspaceName), + slog.F("state_length", len(jobType.WorkspaceBuild.State)), + slog.F("parameters", jobType.WorkspaceBuild.ParameterValues), + ) + return r.runWorkspaceBuild() + default: + return nil, r.failedJobf("unknown job type %q; ensure your provisioner daemon is up-to-date", + reflect.TypeOf(r.job.Type).String()) + } +} + +// heartbeat periodically sends updates on the job, which keeps coder server from assuming the job +// is stalled, and allows the runner to learn if the job has been canceled by the user. +func (r *Runner) heartbeat() { + ticker := time.NewTicker(r.updateInterval) + defer ticker.Stop() + for { + select { + case <-r.notCanceled.Done(): + return + case <-ticker.C: + } + + resp, err := r.update(r.notStopped, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + }) + if err != nil { + err = r.Fail(r.notStopped, r.failedJobf("send periodic update: %s", err)) + if err != nil { + r.logger.Error(r.notStopped, "failed to call FailJob", slog.Error(err)) + } + return + } + if !resp.Canceled { + continue + } + r.logger.Info(r.notStopped, "attempting graceful cancelation") + r.Cancel() + // Hard-cancel the job after a minute of pending cancelation. + timer := time.NewTimer(r.forceCancelInterval) + select { + case <-timer.C: + r.logger.Warn(r.notStopped, "Cancel timed out") + err := r.Fail(r.notStopped, r.failedJobf("Cancel timed out")) + if err != nil { + r.logger.Warn(r.notStopped, "failed to call FailJob", slog.Error(err)) + } + return + case <-r.Done(): + timer.Stop() + return + case <-r.notStopped.Done(): + timer.Stop() + return + } + } +} + +// ReadmeFile is the location we look for to extract documentation from template +// versions. 
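+// If present, its contents are attached to the template version via the Readme field of UpdateJob.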
+const ReadmeFile = "README.md" + +func (r *Runner) runReadmeParse() *proto.FailedJob { + fi, err := afero.ReadFile(r.filesystem, path.Join(r.workDirectory, ReadmeFile)) + if err != nil { + _, err := r.update(r.notStopped, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER_DAEMON, + Level: sdkproto.LogLevel_DEBUG, + Stage: "No README.md provided", + CreatedAt: time.Now().UTC().UnixMilli(), + }}, + }) + if err != nil { + return r.failedJobf("write log: %s", err) + } + + return nil + } + + _, err = r.update(r.notStopped, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER_DAEMON, + Level: sdkproto.LogLevel_INFO, + Stage: "Adding README.md...", + CreatedAt: time.Now().UTC().UnixMilli(), + }}, + Readme: fi, + }) + if err != nil { + return r.failedJobf("write log: %s", err) + } + return nil +} + +func (r *Runner) runTemplateImport() (*proto.CompletedJob, *proto.FailedJob) { + // Parse parameters and update the job with the parameter specs + _, err := r.update(r.notStopped, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER_DAEMON, + Level: sdkproto.LogLevel_INFO, + Stage: "Parsing template parameters", + CreatedAt: time.Now().UTC().UnixMilli(), + }}, + }) + if err != nil { + return nil, r.failedJobf("write log: %s", err) + } + parameterSchemas, err := r.runTemplateImportParse() + if err != nil { + return nil, r.failedJobf("run parse: %s", err) + } + updateResponse, err := r.update(r.notStopped, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + ParameterSchemas: parameterSchemas, + }) + if err != nil { + return nil, r.failedJobf("update job: %s", err) + } + + valueByName := map[string]*sdkproto.ParameterValue{} + for _, parameterValue := range updateResponse.ParameterValues { + valueByName[parameterValue.Name] = parameterValue + } + for _, parameterSchema := range parameterSchemas { + _, ok := valueByName[parameterSchema.Name] + if !ok { + return nil, r.failedJobf("%s: %s", MissingParameterErrorText, parameterSchema.Name) + } + } + + // Determine persistent resources + _, err = r.update(r.notStopped, &proto.UpdateJobRequest{ + JobId: r.job.JobId, + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER_DAEMON, + Level: sdkproto.LogLevel_INFO, + Stage: "Detecting persistent resources", + CreatedAt: time.Now().UTC().UnixMilli(), + }}, + }) + if err != nil { + return nil, r.failedJobf("write log: %s", err) + } + startResources, err := r.runTemplateImportProvision(updateResponse.ParameterValues, &sdkproto.Provision_Metadata{ + CoderUrl: r.job.GetTemplateImport().Metadata.CoderUrl, + WorkspaceTransition: sdkproto.WorkspaceTransition_START, + }) + if err != nil { + return nil, r.failedJobf("template import provision for start: %s", err) + } + + // Determine ephemeral resources. 
+
+	// Determine ephemeral resources.
+	_, err = r.update(r.notStopped, &proto.UpdateJobRequest{
+		JobId: r.job.JobId,
+		Logs: []*proto.Log{{
+			Source:    proto.LogSource_PROVISIONER_DAEMON,
+			Level:     sdkproto.LogLevel_INFO,
+			Stage:     "Detecting ephemeral resources",
+			CreatedAt: time.Now().UTC().UnixMilli(),
+		}},
+	})
+	if err != nil {
+		return nil, r.failedJobf("write log: %s", err)
+	}
+	stopResources, err := r.runTemplateImportProvision(updateResponse.ParameterValues, &sdkproto.Provision_Metadata{
+		CoderUrl:            r.job.GetTemplateImport().Metadata.CoderUrl,
+		WorkspaceTransition: sdkproto.WorkspaceTransition_STOP,
+	})
+	if err != nil {
+		return nil, r.failedJobf("template import provision for stop: %s", err)
+	}
+
+	return &proto.CompletedJob{
+		JobId: r.job.JobId,
+		Type: &proto.CompletedJob_TemplateImport_{
+			TemplateImport: &proto.CompletedJob_TemplateImport{
+				StartResources: startResources,
+				StopResources:  stopResources,
+			},
+		},
+	}, nil
+}
+
+// runTemplateImportParse parses parameter schemas from the template source.
+func (r *Runner) runTemplateImportParse() ([]*sdkproto.ParameterSchema, error) {
+	stream, err := r.provisioner.Parse(r.notStopped, &sdkproto.Parse_Request{
+		Directory: r.workDirectory,
+	})
+	if err != nil {
+		return nil, xerrors.Errorf("parse source: %w", err)
+	}
+	defer stream.Close()
+	for {
+		msg, err := stream.Recv()
+		if err != nil {
+			return nil, xerrors.Errorf("recv parse source: %w", err)
+		}
+		switch msgType := msg.Type.(type) {
+		case *sdkproto.Parse_Response_Log:
+			r.logger.Debug(context.Background(), "parse job logged",
+				slog.F("level", msgType.Log.Level),
+				slog.F("output", msgType.Log.Output),
+			)
+
+			_, err = r.update(r.notStopped, &proto.UpdateJobRequest{
+				JobId: r.job.JobId,
+				Logs: []*proto.Log{{
+					Source:    proto.LogSource_PROVISIONER,
+					Level:     msgType.Log.Level,
+					CreatedAt: time.Now().UTC().UnixMilli(),
+					Output:    msgType.Log.Output,
+					Stage:     "Parse parameters",
+				}},
+			})
+			if err != nil {
+				return nil, xerrors.Errorf("update job: %w", err)
+			}
+		case *sdkproto.Parse_Response_Complete:
+			r.logger.Info(context.Background(), "parse complete",
+				slog.F("parameter_schemas", msgType.Complete.ParameterSchemas))
+
+			return msgType.Complete.ParameterSchemas, nil
+		default:
+			return nil, xerrors.Errorf("invalid message type %q received from provisioner",
+				reflect.TypeOf(msg.Type).String())
+		}
+	}
+}
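+
+// The provisioner streams below follow the same contract as Parse above: the
+// provisioner emits zero or more Log responses, which are relayed to the
+// server as job logs, and exactly one Complete response that terminates the
+// loop; any other message type is treated as a protocol error.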
+
+// runTemplateImportProvision performs a dry-run provision when importing a
+// template. It is used to detect resources that would be provisioned for a
+// workspace in various states.
+func (r *Runner) runTemplateImportProvision(values []*sdkproto.ParameterValue, metadata *sdkproto.Provision_Metadata) ([]*sdkproto.Resource, error) {
+	var stage string
+	switch metadata.WorkspaceTransition {
+	case sdkproto.WorkspaceTransition_START:
+		stage = "Detecting persistent resources"
+	case sdkproto.WorkspaceTransition_STOP:
+		stage = "Detecting ephemeral resources"
+	}
+	// Use notStopped so that if we attempt to gracefully cancel, the stream
+	// will still be available for us to send the cancel to the provisioner.
+	stream, err := r.provisioner.Provision(r.notStopped)
+	if err != nil {
+		return nil, xerrors.Errorf("provision: %w", err)
+	}
+	defer stream.Close()
+	go func() {
+		select {
+		case <-r.notStopped.Done():
+			return
+		case <-r.notCanceled.Done():
+			_ = stream.Send(&sdkproto.Provision_Request{
+				Type: &sdkproto.Provision_Request_Cancel{
+					Cancel: &sdkproto.Provision_Cancel{},
+				},
+			})
+		}
+	}()
+	err = stream.Send(&sdkproto.Provision_Request{
+		Type: &sdkproto.Provision_Request_Start{
+			Start: &sdkproto.Provision_Start{
+				Directory:       r.workDirectory,
+				ParameterValues: values,
+				DryRun:          true,
+				Metadata:        metadata,
+			},
+		},
+	})
+	if err != nil {
+		return nil, xerrors.Errorf("start provision: %w", err)
+	}
+
+	for {
+		msg, err := stream.Recv()
+		if err != nil {
+			return nil, xerrors.Errorf("recv import provision: %w", err)
+		}
+		switch msgType := msg.Type.(type) {
+		case *sdkproto.Provision_Response_Log:
+			r.logger.Debug(context.Background(), "template import provision job logged",
+				slog.F("level", msgType.Log.Level),
+				slog.F("output", msgType.Log.Output),
+			)
+			_, err = r.update(r.notStopped, &proto.UpdateJobRequest{
+				JobId: r.job.JobId,
+				Logs: []*proto.Log{{
+					Source:    proto.LogSource_PROVISIONER,
+					Level:     msgType.Log.Level,
+					CreatedAt: time.Now().UTC().UnixMilli(),
+					Output:    msgType.Log.Output,
+					Stage:     stage,
+				}},
+			})
+			if err != nil {
+				return nil, xerrors.Errorf("send job update: %w", err)
+			}
+		case *sdkproto.Provision_Response_Complete:
+			if msgType.Complete.Error != "" {
+				r.logger.Info(context.Background(), "dry-run provision failure",
+					slog.F("error", msgType.Complete.Error),
+				)
+
+				return nil, xerrors.New(msgType.Complete.Error)
+			}
+
+			r.logger.Info(context.Background(), "dry-run provision successful",
+				slog.F("resource_count", len(msgType.Complete.Resources)),
+				slog.F("resources", msgType.Complete.Resources),
+				slog.F("state_length", len(msgType.Complete.State)),
+			)
+
+			return msgType.Complete.Resources, nil
+		default:
+			return nil, xerrors.Errorf("invalid message type %q received from provisioner",
+				reflect.TypeOf(msg.Type).String())
+		}
+	}
+}
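+
+// Cancelation of a provision is relayed in-band: a goroutine watches
+// notCanceled and, when it fires, sends a Provision_Cancel message on the
+// open stream rather than tearing the stream down, which is why the stream
+// itself is opened with the longer-lived notStopped context.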
+
+func (r *Runner) runTemplateDryRun() (*proto.CompletedJob, *proto.FailedJob) {
+	// Ensure all metadata fields are set, as they are all optional for dry-run.
+	metadata := r.job.GetTemplateDryRun().GetMetadata()
+	metadata.WorkspaceTransition = sdkproto.WorkspaceTransition_START
+	if metadata.CoderUrl == "" {
+		metadata.CoderUrl = "http://localhost:3000"
+	}
+	if metadata.WorkspaceName == "" {
+		metadata.WorkspaceName = "dryrun"
+	}
+	metadata.WorkspaceOwner = r.job.UserName
+	if metadata.WorkspaceOwner == "" {
+		metadata.WorkspaceOwner = "dryrunner"
+	}
+	if metadata.WorkspaceId == "" {
+		id, err := uuid.NewRandom()
+		if err != nil {
+			return nil, r.failedJobf("generate random ID: %s", err)
+		}
+		metadata.WorkspaceId = id.String()
+	}
+	if metadata.WorkspaceOwnerId == "" {
+		id, err := uuid.NewRandom()
+		if err != nil {
+			return nil, r.failedJobf("generate random ID: %s", err)
+		}
+		metadata.WorkspaceOwnerId = id.String()
+	}
+
+	// Run the template import provision task since it's already a dry run.
+	resources, err := r.runTemplateImportProvision(
+		r.job.GetTemplateDryRun().GetParameterValues(),
+		metadata,
+	)
+	if err != nil {
+		return nil, r.failedJobf("run dry-run provision job: %s", err)
+	}
+
+	return &proto.CompletedJob{
+		JobId: r.job.JobId,
+		Type: &proto.CompletedJob_TemplateDryRun_{
+			TemplateDryRun: &proto.CompletedJob_TemplateDryRun{
+				Resources: resources,
+			},
+		},
+	}, nil
+}
+
+func (r *Runner) runWorkspaceBuild() (*proto.CompletedJob, *proto.FailedJob) {
+	var stage string
+	switch r.job.GetWorkspaceBuild().Metadata.WorkspaceTransition {
+	case sdkproto.WorkspaceTransition_START:
+		stage = "Starting workspace"
+	case sdkproto.WorkspaceTransition_STOP:
+		stage = "Stopping workspace"
+	case sdkproto.WorkspaceTransition_DESTROY:
+		stage = "Destroying workspace"
+	}
+
+	_, err := r.update(r.notStopped, &proto.UpdateJobRequest{
+		JobId: r.job.JobId,
+		Logs: []*proto.Log{{
+			Source:    proto.LogSource_PROVISIONER_DAEMON,
+			Level:     sdkproto.LogLevel_INFO,
+			Stage:     stage,
+			CreatedAt: time.Now().UTC().UnixMilli(),
+		}},
+	})
+	if err != nil {
+		return nil, r.failedJobf("write log: %s", err)
+	}
+
+	// Use notStopped so that if we attempt to gracefully cancel, the stream
+	// will still be available for us to send the cancel to the provisioner.
+	stream, err := r.provisioner.Provision(r.notStopped)
+	if err != nil {
+		return nil, r.failedJobf("provision: %s", err)
+	}
+	defer stream.Close()
+	go func() {
+		select {
+		case <-r.notStopped.Done():
+			return
+		case <-r.notCanceled.Done():
+			_ = stream.Send(&sdkproto.Provision_Request{
+				Type: &sdkproto.Provision_Request_Cancel{
+					Cancel: &sdkproto.Provision_Cancel{},
+				},
+			})
+		}
+	}()
+	err = stream.Send(&sdkproto.Provision_Request{
+		Type: &sdkproto.Provision_Request_Start{
+			Start: &sdkproto.Provision_Start{
+				Directory:       r.workDirectory,
+				ParameterValues: r.job.GetWorkspaceBuild().ParameterValues,
+				Metadata:        r.job.GetWorkspaceBuild().Metadata,
+				State:           r.job.GetWorkspaceBuild().State,
+			},
+		},
+	})
+	if err != nil {
+		return nil, r.failedJobf("start provision: %s", err)
+	}
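+
+	// On failure, the FailedJob returned below still carries the latest
+	// provisioner state, so a partially applied build is not lost; a
+	// successful Complete response carries both the state and the final
+	// resource list.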
+	for {
+		msg, err := stream.Recv()
+		if err != nil {
+			return nil, r.failedJobf("recv workspace provision: %s", err)
+		}
+		switch msgType := msg.Type.(type) {
+		case *sdkproto.Provision_Response_Log:
+			r.logger.Debug(context.Background(), "workspace provision job logged",
+				slog.F("level", msgType.Log.Level),
+				slog.F("output", msgType.Log.Output),
+				slog.F("workspace_build_id", r.job.GetWorkspaceBuild().WorkspaceBuildId),
+			)
+
+			_, err = r.update(r.notStopped, &proto.UpdateJobRequest{
+				JobId: r.job.JobId,
+				Logs: []*proto.Log{{
+					Source:    proto.LogSource_PROVISIONER,
+					Level:     msgType.Log.Level,
+					CreatedAt: time.Now().UTC().UnixMilli(),
+					Output:    msgType.Log.Output,
+					Stage:     stage,
+				}},
+			})
+			if err != nil {
+				return nil, r.failedJobf("send job update: %s", err)
+			}
+		case *sdkproto.Provision_Response_Complete:
+			if msgType.Complete.Error != "" {
+				r.logger.Info(context.Background(), "provision failed; updating state",
+					slog.F("state_length", len(msgType.Complete.State)),
+				)
+
+				return nil, &proto.FailedJob{
+					JobId: r.job.JobId,
+					Error: msgType.Complete.Error,
+					Type: &proto.FailedJob_WorkspaceBuild_{
+						WorkspaceBuild: &proto.FailedJob_WorkspaceBuild{
+							State: msgType.Complete.State,
+						},
+					},
+				}
+			}
+
+			r.logger.Debug(context.Background(), "provision complete, no error")
+			r.logger.Info(context.Background(), "provision successful; marked job as complete",
+				slog.F("resource_count", len(msgType.Complete.Resources)),
+				slog.F("resources", msgType.Complete.Resources),
+				slog.F("state_length", len(msgType.Complete.State)),
+			)
+			// Stop looping!
+			return &proto.CompletedJob{
+				JobId: r.job.JobId,
+				Type: &proto.CompletedJob_WorkspaceBuild_{
+					WorkspaceBuild: &proto.CompletedJob_WorkspaceBuild{
+						State:     msgType.Complete.State,
+						Resources: msgType.Complete.Resources,
+					},
+				},
+			}, nil
+		default:
+			return nil, r.failedJobf("invalid message type %T received from provisioner", msg.Type)
+		}
+	}
+}
+
+func (r *Runner) failedJobf(format string, args ...interface{}) *proto.FailedJob {
+	return &proto.FailedJob{
+		JobId: r.job.JobId,
+		Error: fmt.Sprintf(format, args...),
+	}
+}
diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go
index e465a35d06d24..178998b5a21f9 100644
--- a/pty/ptytest/ptytest.go
+++ b/pty/ptytest/ptytest.go
@@ -89,7 +89,7 @@ func (p *PTY) ExpectMatch(str string) string {
 		case <-timer.C:
 		}
 		_ = p.Close()
-		p.t.Errorf("match exceeded deadline: wanted %q; got %q", str, buffer.String())
+		p.t.Errorf("%s match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String())
 	}()
 	for {
 		var r rune