diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index 423e9bbe584c6..9c4067137b852 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -1340,14 +1340,56 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) switch jobType := completed.Type.(type) { case *proto.CompletedJob_TemplateImport_: - var input TemplateVersionImportJob - err = json.Unmarshal(job.Input, &input) + err = s.completeTemplateImportJob(ctx, job, jobID, jobType, telemetrySnapshot) + if err != nil { + return nil, err + } + case *proto.CompletedJob_WorkspaceBuild_: + err = s.completeWorkspaceBuildJob(ctx, job, jobID, jobType, telemetrySnapshot) + if err != nil { + return nil, err + } + case *proto.CompletedJob_TemplateDryRun_: + err = s.completeTemplateDryRunJob(ctx, job, jobID, jobType, telemetrySnapshot) if err != nil { - return nil, xerrors.Errorf("template version ID is expected: %w", err) + return nil, err + } + default: + if completed.Type == nil { + return nil, xerrors.Errorf("type payload must be provided") } + return nil, xerrors.Errorf("unknown job type %q; ensure coderd and provisionerd versions match", + reflect.TypeOf(completed.Type).String()) + } + + data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{EndOfLogs: true}) + if err != nil { + return nil, xerrors.Errorf("marshal job log: %w", err) + } + err = s.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data) + if err != nil { + s.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err)) + return nil, xerrors.Errorf("publish end of job logs: %w", err) + } + s.Logger.Debug(ctx, "stage CompleteJob done", slog.F("job_id", jobID)) + return &proto.Empty{}, nil +} + +// completeTemplateImportJob handles completion of a template import job. +// All database operations are performed within a transaction. +func (s *server) completeTemplateImportJob(ctx context.Context, job database.ProvisionerJob, jobID uuid.UUID, jobType *proto.CompletedJob_TemplateImport_, telemetrySnapshot *telemetry.Snapshot) error { + var input TemplateVersionImportJob + err := json.Unmarshal(job.Input, &input) + if err != nil { + return xerrors.Errorf("template version ID is expected: %w", err) + } + + // Execute all database operations in a transaction + return s.Database.InTx(func(db database.Store) error { now := s.timeNow() + // Process resources for transition, resources := range map[database.WorkspaceTransition][]*sdkproto.Resource{ database.WorkspaceTransitionStart: jobType.TemplateImport.StartResources, database.WorkspaceTransitionStop: jobType.TemplateImport.StopResources, @@ -1359,11 +1401,13 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) slog.F("resource_type", resource.Type), slog.F("transition", transition)) - if err := InsertWorkspaceResource(ctx, s.Database, jobID, transition, resource, telemetrySnapshot); err != nil { - return nil, xerrors.Errorf("insert resource: %w", err) + if err := InsertWorkspaceResource(ctx, db, jobID, transition, resource, telemetrySnapshot); err != nil { + return xerrors.Errorf("insert resource: %w", err) } } } + + // Process modules for transition, modules := range map[database.WorkspaceTransition][]*sdkproto.Module{ database.WorkspaceTransitionStart: jobType.TemplateImport.StartModules, database.WorkspaceTransitionStop: jobType.TemplateImport.StopModules, @@ -1376,12 +1420,13 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) slog.F("module_key", module.Key), slog.F("transition", transition)) - if err := InsertWorkspaceModule(ctx, s.Database, jobID, transition, module, telemetrySnapshot); err != nil { - return nil, xerrors.Errorf("insert module: %w", err) + if err := InsertWorkspaceModule(ctx, db, jobID, transition, module, telemetrySnapshot); err != nil { + return xerrors.Errorf("insert module: %w", err) } } } + // Process rich parameters for _, richParameter := range jobType.TemplateImport.RichParameters { s.Logger.Info(ctx, "inserting template import job parameter", slog.F("job_id", job.ID.String()), @@ -1391,7 +1436,7 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) ) options, err := json.Marshal(richParameter.Options) if err != nil { - return nil, xerrors.Errorf("marshal parameter options: %w", err) + return xerrors.Errorf("marshal parameter options: %w", err) } var validationMin, validationMax sql.NullInt32 @@ -1408,7 +1453,7 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) } } - _, err = s.Database.InsertTemplateVersionParameter(ctx, database.InsertTemplateVersionParameterParams{ + _, err = db.InsertTemplateVersionParameter(ctx, database.InsertTemplateVersionParameterParams{ TemplateVersionID: input.TemplateVersionID, Name: richParameter.Name, DisplayName: richParameter.DisplayName, @@ -1428,15 +1473,17 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) Ephemeral: richParameter.Ephemeral, }) if err != nil { - return nil, xerrors.Errorf("insert parameter: %w", err) + return xerrors.Errorf("insert parameter: %w", err) } } - err = InsertWorkspacePresetsAndParameters(ctx, s.Logger, s.Database, jobID, input.TemplateVersionID, jobType.TemplateImport.Presets, now) + // Process presets and parameters + err := InsertWorkspacePresetsAndParameters(ctx, s.Logger, db, jobID, input.TemplateVersionID, jobType.TemplateImport.Presets, now) if err != nil { - return nil, xerrors.Errorf("insert workspace presets and parameters: %w", err) + return xerrors.Errorf("insert workspace presets and parameters: %w", err) } + // Process external auth providers var completedError sql.NullString for _, externalAuthProvider := range jobType.TemplateImport.ExternalAuthProviders { @@ -1479,18 +1526,19 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) externalAuthProvidersMessage, err := json.Marshal(externalAuthProviders) if err != nil { - return nil, xerrors.Errorf("failed to serialize external_auth_providers value: %w", err) + return xerrors.Errorf("failed to serialize external_auth_providers value: %w", err) } - err = s.Database.UpdateTemplateVersionExternalAuthProvidersByJobID(ctx, database.UpdateTemplateVersionExternalAuthProvidersByJobIDParams{ + err = db.UpdateTemplateVersionExternalAuthProvidersByJobID(ctx, database.UpdateTemplateVersionExternalAuthProvidersByJobIDParams{ JobID: jobID, ExternalAuthProviders: externalAuthProvidersMessage, UpdatedAt: now, }) if err != nil { - return nil, xerrors.Errorf("update template version external auth providers: %w", err) + return xerrors.Errorf("update template version external auth providers: %w", err) } + // Process terraform values plan := jobType.TemplateImport.Plan moduleFiles := jobType.TemplateImport.ModuleFiles // If there is a plan, or a module files archive we need to insert a @@ -1509,7 +1557,7 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) hash := hex.EncodeToString(hashBytes[:]) // nolint:gocritic // Requires reading "system" files - file, err := s.Database.GetFileByHashAndCreator(dbauthz.AsSystemRestricted(ctx), database.GetFileByHashAndCreatorParams{Hash: hash, CreatedBy: uuid.Nil}) + file, err := db.GetFileByHashAndCreator(dbauthz.AsSystemRestricted(ctx), database.GetFileByHashAndCreatorParams{Hash: hash, CreatedBy: uuid.Nil}) switch { case err == nil: // This set of modules is already cached, which means we can reuse them @@ -1518,10 +1566,10 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) UUID: file.ID, } case !xerrors.Is(err, sql.ErrNoRows): - return nil, xerrors.Errorf("check for cached modules: %w", err) + return xerrors.Errorf("check for cached modules: %w", err) default: // nolint:gocritic // Requires creating a "system" file - file, err = s.Database.InsertFile(dbauthz.AsSystemRestricted(ctx), database.InsertFileParams{ + file, err = db.InsertFile(dbauthz.AsSystemRestricted(ctx), database.InsertFileParams{ ID: uuid.New(), Hash: hash, CreatedBy: uuid.Nil, @@ -1530,7 +1578,7 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) Data: moduleFiles, }) if err != nil { - return nil, xerrors.Errorf("insert template version terraform modules: %w", err) + return xerrors.Errorf("insert template version terraform modules: %w", err) } fileID = uuid.NullUUID{ Valid: true, @@ -1539,7 +1587,7 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) } } - err = s.Database.InsertTemplateVersionTerraformValuesByJobID(ctx, database.InsertTemplateVersionTerraformValuesByJobIDParams{ + err = db.InsertTemplateVersionTerraformValuesByJobID(ctx, database.InsertTemplateVersionTerraformValuesByJobIDParams{ JobID: jobID, UpdatedAt: now, CachedPlan: plan, @@ -1547,11 +1595,12 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) ProvisionerdVersion: s.apiVersion, }) if err != nil { - return nil, xerrors.Errorf("insert template version terraform data: %w", err) + return xerrors.Errorf("insert template version terraform data: %w", err) } } - err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + // Mark job as completed + err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ ID: jobID, UpdatedAt: now, CompletedAt: sql.NullTime{ @@ -1562,206 +1611,136 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) ErrorCode: sql.NullString{}, }) if err != nil { - return nil, xerrors.Errorf("update provisioner job: %w", err) + return xerrors.Errorf("update provisioner job: %w", err) } s.Logger.Debug(ctx, "marked import job as completed", slog.F("job_id", jobID)) - case *proto.CompletedJob_WorkspaceBuild_: - var input WorkspaceProvisionJob - err = json.Unmarshal(job.Input, &input) - if err != nil { - return nil, xerrors.Errorf("unmarshal job data: %w", err) - } + return nil + }, nil) // End of transaction +} - workspaceBuild, err := s.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID) - if err != nil { - return nil, xerrors.Errorf("get workspace build: %w", err) - } +// completeWorkspaceBuildJob handles completion of a workspace build job. +// Most database operations are performed within a transaction. +func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.ProvisionerJob, jobID uuid.UUID, jobType *proto.CompletedJob_WorkspaceBuild_, telemetrySnapshot *telemetry.Snapshot) error { + var input WorkspaceProvisionJob + err := json.Unmarshal(job.Input, &input) + if err != nil { + return xerrors.Errorf("unmarshal job data: %w", err) + } - var workspace database.Workspace - var getWorkspaceError error + workspaceBuild, err := s.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID) + if err != nil { + return xerrors.Errorf("get workspace build: %w", err) + } - err = s.Database.InTx(func(db database.Store) error { - // It's important we use s.timeNow() here because we want to be - // able to customize the current time from within tests. - now := s.timeNow() - - workspace, getWorkspaceError = db.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID) - if getWorkspaceError != nil { - s.Logger.Error(ctx, - "fetch workspace for build", - slog.F("workspace_build_id", workspaceBuild.ID), - slog.F("workspace_id", workspaceBuild.WorkspaceID), - ) - return getWorkspaceError - } + var workspace database.Workspace + var getWorkspaceError error - templateScheduleStore := *s.TemplateScheduleStore.Load() + // Execute all database modifications in a transaction + err = s.Database.InTx(func(db database.Store) error { + // It's important we use s.timeNow() here because we want to be + // able to customize the current time from within tests. + now := s.timeNow() - autoStop, err := schedule.CalculateAutostop(ctx, schedule.CalculateAutostopParams{ - Database: db, - TemplateScheduleStore: templateScheduleStore, - UserQuietHoursScheduleStore: *s.UserQuietHoursScheduleStore.Load(), - Now: now, - Workspace: workspace.WorkspaceTable(), - // Allowed to be the empty string. - WorkspaceAutostart: workspace.AutostartSchedule.String, - }) - if err != nil { - return xerrors.Errorf("calculate auto stop: %w", err) - } + workspace, getWorkspaceError = db.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID) + if getWorkspaceError != nil { + s.Logger.Error(ctx, + "fetch workspace for build", + slog.F("workspace_build_id", workspaceBuild.ID), + slog.F("workspace_id", workspaceBuild.WorkspaceID), + ) + return getWorkspaceError + } - if workspace.AutostartSchedule.Valid { - templateScheduleOptions, err := templateScheduleStore.Get(ctx, db, workspace.TemplateID) - if err != nil { - return xerrors.Errorf("get template schedule options: %w", err) - } + templateScheduleStore := *s.TemplateScheduleStore.Load() - nextStartAt, err := schedule.NextAllowedAutostart(now, workspace.AutostartSchedule.String, templateScheduleOptions) - if err == nil { - err = db.UpdateWorkspaceNextStartAt(ctx, database.UpdateWorkspaceNextStartAtParams{ - ID: workspace.ID, - NextStartAt: sql.NullTime{Valid: true, Time: nextStartAt.UTC()}, - }) - if err != nil { - return xerrors.Errorf("update workspace next start at: %w", err) - } - } - } + autoStop, err := schedule.CalculateAutostop(ctx, schedule.CalculateAutostopParams{ + Database: db, + TemplateScheduleStore: templateScheduleStore, + UserQuietHoursScheduleStore: *s.UserQuietHoursScheduleStore.Load(), + Now: now, + Workspace: workspace.WorkspaceTable(), + // Allowed to be the empty string. + WorkspaceAutostart: workspace.AutostartSchedule.String, + }) + if err != nil { + return xerrors.Errorf("calculate auto stop: %w", err) + } - err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ - ID: jobID, - UpdatedAt: now, - CompletedAt: sql.NullTime{ - Time: now, - Valid: true, - }, - Error: sql.NullString{}, - ErrorCode: sql.NullString{}, - }) + if workspace.AutostartSchedule.Valid { + templateScheduleOptions, err := templateScheduleStore.Get(ctx, db, workspace.TemplateID) if err != nil { - return xerrors.Errorf("update provisioner job: %w", err) + return xerrors.Errorf("get template schedule options: %w", err) } - err = db.UpdateWorkspaceBuildProvisionerStateByID(ctx, database.UpdateWorkspaceBuildProvisionerStateByIDParams{ - ID: workspaceBuild.ID, - ProvisionerState: jobType.WorkspaceBuild.State, - UpdatedAt: now, - }) - if err != nil { - return xerrors.Errorf("update workspace build provisioner state: %w", err) - } - err = db.UpdateWorkspaceBuildDeadlineByID(ctx, database.UpdateWorkspaceBuildDeadlineByIDParams{ - ID: workspaceBuild.ID, - Deadline: autoStop.Deadline, - MaxDeadline: autoStop.MaxDeadline, - UpdatedAt: now, - }) - if err != nil { - return xerrors.Errorf("update workspace build deadline: %w", err) - } - - agentTimeouts := make(map[time.Duration]bool) // A set of agent timeouts. - // This could be a bulk insert to improve performance. - for _, protoResource := range jobType.WorkspaceBuild.Resources { - for _, protoAgent := range protoResource.Agents { - dur := time.Duration(protoAgent.GetConnectionTimeoutSeconds()) * time.Second - agentTimeouts[dur] = true - } - err = InsertWorkspaceResource(ctx, db, job.ID, workspaceBuild.Transition, protoResource, telemetrySnapshot) + nextStartAt, err := schedule.NextAllowedAutostart(now, workspace.AutostartSchedule.String, templateScheduleOptions) + if err == nil { + err = db.UpdateWorkspaceNextStartAt(ctx, database.UpdateWorkspaceNextStartAtParams{ + ID: workspace.ID, + NextStartAt: sql.NullTime{Valid: true, Time: nextStartAt.UTC()}, + }) if err != nil { - return xerrors.Errorf("insert provisioner job: %w", err) - } - } - for _, module := range jobType.WorkspaceBuild.Modules { - if err := InsertWorkspaceModule(ctx, db, job.ID, workspaceBuild.Transition, module, telemetrySnapshot); err != nil { - return xerrors.Errorf("insert provisioner job module: %w", err) + return xerrors.Errorf("update workspace next start at: %w", err) } } + } - // On start, we want to ensure that workspace agents timeout statuses - // are propagated. This method is simple and does not protect against - // notifying in edge cases like when a workspace is stopped soon - // after being started. - // - // Agent timeouts could be minutes apart, resulting in an unresponsive - // experience, so we'll notify after every unique timeout seconds. - if !input.DryRun && workspaceBuild.Transition == database.WorkspaceTransitionStart && len(agentTimeouts) > 0 { - timeouts := maps.Keys(agentTimeouts) - slices.Sort(timeouts) - - var updates []<-chan time.Time - for _, d := range timeouts { - s.Logger.Debug(ctx, "triggering workspace notification after agent timeout", - slog.F("workspace_build_id", workspaceBuild.ID), - slog.F("timeout", d), - ) - // Agents are inserted with `dbtime.Now()`, this triggers a - // workspace event approximately after created + timeout seconds. - updates = append(updates, time.After(d)) - } - go func() { - for _, wait := range updates { - select { - case <-s.lifecycleCtx.Done(): - // If the server is shutting down, we don't want to wait around. - s.Logger.Debug(ctx, "stopping notifications due to server shutdown", - slog.F("workspace_build_id", workspaceBuild.ID), - ) - return - case <-wait: - // Wait for the next potential timeout to occur. - msg, err := json.Marshal(wspubsub.WorkspaceEvent{ - Kind: wspubsub.WorkspaceEventKindAgentTimeout, - WorkspaceID: workspace.ID, - }) - if err != nil { - s.Logger.Error(ctx, "marshal workspace update event", slog.Error(err)) - break - } - if err := s.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg); err != nil { - if s.lifecycleCtx.Err() != nil { - // If the server is shutting down, we don't want to log this error, nor wait around. - s.Logger.Debug(ctx, "stopping notifications due to server shutdown", - slog.F("workspace_build_id", workspaceBuild.ID), - ) - return - } - s.Logger.Error(ctx, "workspace notification after agent timeout failed", - slog.F("workspace_build_id", workspaceBuild.ID), - slog.Error(err), - ) - } - } - } - }() - } + err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + ID: jobID, + UpdatedAt: now, + CompletedAt: sql.NullTime{ + Time: now, + Valid: true, + }, + Error: sql.NullString{}, + ErrorCode: sql.NullString{}, + }) + if err != nil { + return xerrors.Errorf("update provisioner job: %w", err) + } + err = db.UpdateWorkspaceBuildProvisionerStateByID(ctx, database.UpdateWorkspaceBuildProvisionerStateByIDParams{ + ID: workspaceBuild.ID, + ProvisionerState: jobType.WorkspaceBuild.State, + UpdatedAt: now, + }) + if err != nil { + return xerrors.Errorf("update workspace build provisioner state: %w", err) + } + err = db.UpdateWorkspaceBuildDeadlineByID(ctx, database.UpdateWorkspaceBuildDeadlineByIDParams{ + ID: workspaceBuild.ID, + Deadline: autoStop.Deadline, + MaxDeadline: autoStop.MaxDeadline, + UpdatedAt: now, + }) + if err != nil { + return xerrors.Errorf("update workspace build deadline: %w", err) + } - if workspaceBuild.Transition != database.WorkspaceTransitionDelete { - // This is for deleting a workspace! - return nil + agentTimeouts := make(map[time.Duration]bool) // A set of agent timeouts. + // This could be a bulk insert to improve performance. + for _, protoResource := range jobType.WorkspaceBuild.Resources { + for _, protoAgent := range protoResource.Agents { + dur := time.Duration(protoAgent.GetConnectionTimeoutSeconds()) * time.Second + agentTimeouts[dur] = true } - err = db.UpdateWorkspaceDeletedByID(ctx, database.UpdateWorkspaceDeletedByIDParams{ - ID: workspaceBuild.WorkspaceID, - Deleted: true, - }) + err = InsertWorkspaceResource(ctx, db, job.ID, workspaceBuild.Transition, protoResource, telemetrySnapshot) if err != nil { - return xerrors.Errorf("update workspace deleted: %w", err) + return xerrors.Errorf("insert provisioner job: %w", err) + } + } + for _, module := range jobType.WorkspaceBuild.Modules { + if err := InsertWorkspaceModule(ctx, db, job.ID, workspaceBuild.Transition, module, telemetrySnapshot); err != nil { + return xerrors.Errorf("insert provisioner job module: %w", err) } - - return nil - }, nil) - if err != nil { - return nil, xerrors.Errorf("complete job: %w", err) } - // Insert timings outside transaction since it is metadata. + // Insert timings inside the transaction now // nolint:exhaustruct // The other fields are set further down. params := database.InsertProvisionerJobTimingsParams{ JobID: jobID, } - for _, t := range completed.GetWorkspaceBuild().GetTimings() { + for _, t := range jobType.WorkspaceBuild.Timings { if t.Start == nil || t.End == nil { s.Logger.Warn(ctx, "timings entry has nil start or end time", slog.F("entry", t.String())) continue @@ -1780,153 +1759,229 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) params.StartedAt = append(params.StartedAt, t.Start.AsTime()) params.EndedAt = append(params.EndedAt, t.End.AsTime()) } - _, err = s.Database.InsertProvisionerJobTimings(ctx, params) + _, err = db.InsertProvisionerJobTimings(ctx, params) if err != nil { - // Don't fail the transaction for non-critical data. + // Log error but don't fail the whole transaction for non-critical data s.Logger.Warn(ctx, "failed to update provisioner job timings", slog.F("job_id", jobID), slog.Error(err)) } - // audit the outcome of the workspace build - if getWorkspaceError == nil { - // If the workspace has been deleted, notify the owner about it. - if workspaceBuild.Transition == database.WorkspaceTransitionDelete { - s.notifyWorkspaceDeleted(ctx, workspace, workspaceBuild) - } + // On start, we want to ensure that workspace agents timeout statuses + // are propagated. This method is simple and does not protect against + // notifying in edge cases like when a workspace is stopped soon + // after being started. + // + // Agent timeouts could be minutes apart, resulting in an unresponsive + // experience, so we'll notify after every unique timeout seconds. + if !input.DryRun && workspaceBuild.Transition == database.WorkspaceTransitionStart && len(agentTimeouts) > 0 { + timeouts := maps.Keys(agentTimeouts) + slices.Sort(timeouts) + + var updates []<-chan time.Time + for _, d := range timeouts { + s.Logger.Debug(ctx, "triggering workspace notification after agent timeout", + slog.F("workspace_build_id", workspaceBuild.ID), + slog.F("timeout", d), + ) + // Agents are inserted with `dbtime.Now()`, this triggers a + // workspace event approximately after created + timeout seconds. + updates = append(updates, time.After(d)) + } + go func() { + for _, wait := range updates { + select { + case <-s.lifecycleCtx.Done(): + // If the server is shutting down, we don't want to wait around. + s.Logger.Debug(ctx, "stopping notifications due to server shutdown", + slog.F("workspace_build_id", workspaceBuild.ID), + ) + return + case <-wait: + // Wait for the next potential timeout to occur. + msg, err := json.Marshal(wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindAgentTimeout, + WorkspaceID: workspace.ID, + }) + if err != nil { + s.Logger.Error(ctx, "marshal workspace update event", slog.Error(err)) + break + } + if err := s.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg); err != nil { + if s.lifecycleCtx.Err() != nil { + // If the server is shutting down, we don't want to log this error, nor wait around. + s.Logger.Debug(ctx, "stopping notifications due to server shutdown", + slog.F("workspace_build_id", workspaceBuild.ID), + ) + return + } + s.Logger.Error(ctx, "workspace notification after agent timeout failed", + slog.F("workspace_build_id", workspaceBuild.ID), + slog.Error(err), + ) + } + } + } + }() + } - auditor := s.Auditor.Load() - auditAction := auditActionFromTransition(workspaceBuild.Transition) + if workspaceBuild.Transition != database.WorkspaceTransitionDelete { + // This is for deleting a workspace! + return nil + } - previousBuildNumber := workspaceBuild.BuildNumber - 1 - previousBuild, prevBuildErr := s.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{ - WorkspaceID: workspace.ID, - BuildNumber: previousBuildNumber, - }) - if prevBuildErr != nil { - previousBuild = database.WorkspaceBuild{} - } + err = db.UpdateWorkspaceDeletedByID(ctx, database.UpdateWorkspaceDeletedByIDParams{ + ID: workspaceBuild.WorkspaceID, + Deleted: true, + }) + if err != nil { + return xerrors.Errorf("update workspace deleted: %w", err) + } - // We pass the below information to the Auditor so that it - // can form a friendly string for the user to view in the UI. - buildResourceInfo := audit.AdditionalFields{ - WorkspaceName: workspace.Name, - BuildNumber: strconv.FormatInt(int64(workspaceBuild.BuildNumber), 10), - BuildReason: database.BuildReason(string(workspaceBuild.Reason)), - WorkspaceID: workspace.ID, - } + return nil + }, nil) + if err != nil { + return xerrors.Errorf("complete job: %w", err) + } - wriBytes, err := json.Marshal(buildResourceInfo) - if err != nil { - s.Logger.Error(ctx, "marshal resource info for successful job", slog.Error(err)) - } - - bag := audit.BaggageFromContext(ctx) - - audit.BackgroundAudit(ctx, &audit.BackgroundAuditParams[database.WorkspaceBuild]{ - Audit: *auditor, - Log: s.Logger, - UserID: job.InitiatorID, - OrganizationID: workspace.OrganizationID, - RequestID: job.ID, - IP: bag.IP, - Action: auditAction, - Old: previousBuild, - New: workspaceBuild, - Status: http.StatusOK, - AdditionalFields: wriBytes, - }) - } + // Post-transaction operations (operations that do not require transactions or + // are external to the database, like audit logging, notifications, etc.) - if s.PrebuildsOrchestrator != nil && input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM { - // Track resource replacements, if there are any. - orchestrator := s.PrebuildsOrchestrator.Load() - if resourceReplacements := completed.GetWorkspaceBuild().GetResourceReplacements(); orchestrator != nil && len(resourceReplacements) > 0 { - // Fire and forget. Bind to the lifecycle of the server so shutdowns are handled gracefully. - go (*orchestrator).TrackResourceReplacement(s.lifecycleCtx, workspace.ID, workspaceBuild.ID, resourceReplacements) - } + // audit the outcome of the workspace build + if getWorkspaceError == nil { + // If the workspace has been deleted, notify the owner about it. + if workspaceBuild.Transition == database.WorkspaceTransitionDelete { + s.notifyWorkspaceDeleted(ctx, workspace, workspaceBuild) } - msg, err := json.Marshal(wspubsub.WorkspaceEvent{ - Kind: wspubsub.WorkspaceEventKindStateChange, + auditor := s.Auditor.Load() + auditAction := auditActionFromTransition(workspaceBuild.Transition) + + previousBuildNumber := workspaceBuild.BuildNumber - 1 + previousBuild, prevBuildErr := s.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{ WorkspaceID: workspace.ID, + BuildNumber: previousBuildNumber, }) - if err != nil { - return nil, xerrors.Errorf("marshal workspace update event: %s", err) + if prevBuildErr != nil { + previousBuild = database.WorkspaceBuild{} } - err = s.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg) + + // We pass the below information to the Auditor so that it + // can form a friendly string for the user to view in the UI. + buildResourceInfo := audit.AdditionalFields{ + WorkspaceName: workspace.Name, + BuildNumber: strconv.FormatInt(int64(workspaceBuild.BuildNumber), 10), + BuildReason: database.BuildReason(string(workspaceBuild.Reason)), + WorkspaceID: workspace.ID, + } + + wriBytes, err := json.Marshal(buildResourceInfo) if err != nil { - return nil, xerrors.Errorf("update workspace: %w", err) + s.Logger.Error(ctx, "marshal resource info for successful job", slog.Error(err)) + } + + bag := audit.BaggageFromContext(ctx) + + audit.BackgroundAudit(ctx, &audit.BackgroundAuditParams[database.WorkspaceBuild]{ + Audit: *auditor, + Log: s.Logger, + UserID: job.InitiatorID, + OrganizationID: workspace.OrganizationID, + RequestID: job.ID, + IP: bag.IP, + Action: auditAction, + Old: previousBuild, + New: workspaceBuild, + Status: http.StatusOK, + AdditionalFields: wriBytes, + }) + } + + if s.PrebuildsOrchestrator != nil && input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM { + // Track resource replacements, if there are any. + orchestrator := s.PrebuildsOrchestrator.Load() + if resourceReplacements := jobType.WorkspaceBuild.ResourceReplacements; orchestrator != nil && len(resourceReplacements) > 0 { + // Fire and forget. Bind to the lifecycle of the server so shutdowns are handled gracefully. + go (*orchestrator).TrackResourceReplacement(s.lifecycleCtx, workspace.ID, workspaceBuild.ID, resourceReplacements) } + } - if input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM { - s.Logger.Info(ctx, "workspace prebuild successfully claimed by user", - slog.F("workspace_id", workspace.ID)) + msg, err := json.Marshal(wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindStateChange, + WorkspaceID: workspace.ID, + }) + if err != nil { + return xerrors.Errorf("marshal workspace update event: %s", err) + } + err = s.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg) + if err != nil { + return xerrors.Errorf("update workspace: %w", err) + } - err = prebuilds.NewPubsubWorkspaceClaimPublisher(s.Pubsub).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{ - WorkspaceID: workspace.ID, - Reason: agentsdk.ReinitializeReasonPrebuildClaimed, - }) - if err != nil { - s.Logger.Error(ctx, "failed to publish workspace claim event", slog.Error(err)) - } + if input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM { + s.Logger.Info(ctx, "workspace prebuild successfully claimed by user", + slog.F("workspace_id", workspace.ID)) + + err = prebuilds.NewPubsubWorkspaceClaimPublisher(s.Pubsub).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{ + WorkspaceID: workspace.ID, + Reason: agentsdk.ReinitializeReasonPrebuildClaimed, + }) + if err != nil { + s.Logger.Error(ctx, "failed to publish workspace claim event", slog.Error(err)) } - case *proto.CompletedJob_TemplateDryRun_: + } + + return nil +} + +// completeTemplateDryRunJob handles completion of a template dry-run job. +// All database operations are performed within a transaction. +func (s *server) completeTemplateDryRunJob(ctx context.Context, job database.ProvisionerJob, jobID uuid.UUID, jobType *proto.CompletedJob_TemplateDryRun_, telemetrySnapshot *telemetry.Snapshot) error { + // Execute all database operations in a transaction + return s.Database.InTx(func(db database.Store) error { + now := s.timeNow() + + // Process resources for _, resource := range jobType.TemplateDryRun.Resources { s.Logger.Info(ctx, "inserting template dry-run job resource", slog.F("job_id", job.ID.String()), slog.F("resource_name", resource.Name), slog.F("resource_type", resource.Type)) - err = InsertWorkspaceResource(ctx, s.Database, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot) + err := InsertWorkspaceResource(ctx, db, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot) if err != nil { - return nil, xerrors.Errorf("insert resource: %w", err) + return xerrors.Errorf("insert resource: %w", err) } } + + // Process modules for _, module := range jobType.TemplateDryRun.Modules { s.Logger.Info(ctx, "inserting template dry-run job module", slog.F("job_id", job.ID.String()), slog.F("module_source", module.Source), ) - if err := InsertWorkspaceModule(ctx, s.Database, jobID, database.WorkspaceTransitionStart, module, telemetrySnapshot); err != nil { - return nil, xerrors.Errorf("insert module: %w", err) + if err := InsertWorkspaceModule(ctx, db, jobID, database.WorkspaceTransitionStart, module, telemetrySnapshot); err != nil { + return xerrors.Errorf("insert module: %w", err) } } - err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + // Mark job as complete + err := db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ ID: jobID, - UpdatedAt: s.timeNow(), + UpdatedAt: now, CompletedAt: sql.NullTime{ - Time: s.timeNow(), + Time: now, Valid: true, }, Error: sql.NullString{}, ErrorCode: sql.NullString{}, }) if err != nil { - return nil, xerrors.Errorf("update provisioner job: %w", err) + return xerrors.Errorf("update provisioner job: %w", err) } s.Logger.Debug(ctx, "marked template dry-run job as completed", slog.F("job_id", jobID)) - default: - if completed.Type == nil { - return nil, xerrors.Errorf("type payload must be provided") - } - return nil, xerrors.Errorf("unknown job type %q; ensure coderd and provisionerd versions match", - reflect.TypeOf(completed.Type).String()) - } - - data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{EndOfLogs: true}) - if err != nil { - return nil, xerrors.Errorf("marshal job log: %w", err) - } - err = s.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data) - if err != nil { - s.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err)) - return nil, xerrors.Errorf("publish end of job logs: %w", err) - } - - s.Logger.Debug(ctx, "stage CompleteJob done", slog.F("job_id", jobID)) - return &proto.Empty{}, nil + return nil + }, nil) // End of transaction } func (s *server) notifyWorkspaceDeleted(ctx context.Context, workspace database.Workspace, build database.WorkspaceBuild) { diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index e125db348e701..eb63d84b1df1b 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -20,6 +20,7 @@ import ( "go.opentelemetry.io/otel/trace" "golang.org/x/oauth2" "golang.org/x/xerrors" + "google.golang.org/protobuf/types/known/timestamppb" "storj.io/drpc" "cdr.dev/slog/sloggers/slogtest" @@ -1119,6 +1120,227 @@ func TestCompleteJob(t *testing.T) { require.ErrorContains(t, err, "you don't own this job") }) + // Test for verifying transaction behavior on the extracted methods + t.Run("TransactionBehavior", func(t *testing.T) { + t.Parallel() + // Test TemplateImport transaction + t.Run("TemplateImportTransaction", func(t *testing.T) { + t.Parallel() + srv, db, _, pd := setup(t, false, &overrides{}) + jobID := uuid.New() + versionID := uuid.New() + err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ + ID: versionID, + JobID: jobID, + OrganizationID: pd.OrganizationID, + }) + require.NoError(t, err) + job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + OrganizationID: pd.OrganizationID, + ID: jobID, + Provisioner: database.ProvisionerTypeEcho, + Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`), + StorageMethod: database.ProvisionerStorageMethodFile, + Type: database.ProvisionerJobTypeTemplateVersionImport, + }) + require.NoError(t, err) + _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + OrganizationID: pd.OrganizationID, + WorkerID: uuid.NullUUID{ + UUID: pd.ID, + Valid: true, + }, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + }) + require.NoError(t, err) + + _, err = srv.CompleteJob(ctx, &proto.CompletedJob{ + JobId: job.ID.String(), + Type: &proto.CompletedJob_TemplateImport_{ + TemplateImport: &proto.CompletedJob_TemplateImport{ + StartResources: []*sdkproto.Resource{{ + Name: "test-resource", + Type: "aws_instance", + }}, + Plan: []byte("{}"), + }, + }, + }) + require.NoError(t, err) + + // Verify job was marked as completed + completedJob, err := db.GetProvisionerJobByID(ctx, job.ID) + require.NoError(t, err) + require.True(t, completedJob.CompletedAt.Valid, "Job should be marked as completed") + + // Verify resources were created + resources, err := db.GetWorkspaceResourcesByJobID(ctx, job.ID) + require.NoError(t, err) + require.Len(t, resources, 1, "Expected one resource to be created") + require.Equal(t, "test-resource", resources[0].Name) + }) + + // Test TemplateDryRun transaction + t.Run("TemplateDryRunTransaction", func(t *testing.T) { + t.Parallel() + srv, db, _, pd := setup(t, false, &overrides{}) + job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: uuid.New(), + Provisioner: database.ProvisionerTypeEcho, + Type: database.ProvisionerJobTypeTemplateVersionDryRun, + StorageMethod: database.ProvisionerStorageMethodFile, + }) + require.NoError(t, err) + _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + WorkerID: uuid.NullUUID{ + UUID: pd.ID, + Valid: true, + }, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + }) + require.NoError(t, err) + + _, err = srv.CompleteJob(ctx, &proto.CompletedJob{ + JobId: job.ID.String(), + Type: &proto.CompletedJob_TemplateDryRun_{ + TemplateDryRun: &proto.CompletedJob_TemplateDryRun{ + Resources: []*sdkproto.Resource{{ + Name: "test-dry-run-resource", + Type: "aws_instance", + }}, + }, + }, + }) + require.NoError(t, err) + + // Verify job was marked as completed + completedJob, err := db.GetProvisionerJobByID(ctx, job.ID) + require.NoError(t, err) + require.True(t, completedJob.CompletedAt.Valid, "Job should be marked as completed") + + // Verify resources were created + resources, err := db.GetWorkspaceResourcesByJobID(ctx, job.ID) + require.NoError(t, err) + require.Len(t, resources, 1, "Expected one resource to be created") + require.Equal(t, "test-dry-run-resource", resources[0].Name) + }) + + // Test WorkspaceBuild transaction + t.Run("WorkspaceBuildTransaction", func(t *testing.T) { + t.Parallel() + srv, db, ps, pd := setup(t, false, &overrides{}) + + // Create test data + user := dbgen.User(t, db, database.User{}) + template := dbgen.Template(t, db, database.Template{ + Name: "template", + Provisioner: database.ProvisionerTypeEcho, + OrganizationID: pd.OrganizationID, + }) + file := dbgen.File(t, db, database.File{CreatedBy: user.ID}) + workspaceTable := dbgen.Workspace(t, db, database.WorkspaceTable{ + TemplateID: template.ID, + OwnerID: user.ID, + OrganizationID: pd.OrganizationID, + }) + version := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + OrganizationID: pd.OrganizationID, + TemplateID: uuid.NullUUID{ + UUID: template.ID, + Valid: true, + }, + JobID: uuid.New(), + }) + build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + WorkspaceID: workspaceTable.ID, + TemplateVersionID: version.ID, + Transition: database.WorkspaceTransitionStart, + Reason: database.BuildReasonInitiator, + }) + job := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{ + FileID: file.ID, + InitiatorID: user.ID, + Type: database.ProvisionerJobTypeWorkspaceBuild, + Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{ + WorkspaceBuildID: build.ID, + })), + OrganizationID: pd.OrganizationID, + }) + _, err := db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + OrganizationID: pd.OrganizationID, + WorkerID: uuid.NullUUID{ + UUID: pd.ID, + Valid: true, + }, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + }) + require.NoError(t, err) + + // Add a published channel to make sure the workspace event is sent + publishedWorkspace := make(chan struct{}) + closeWorkspaceSubscribe, err := ps.SubscribeWithErr(wspubsub.WorkspaceEventChannel(workspaceTable.OwnerID), + wspubsub.HandleWorkspaceEvent( + func(_ context.Context, e wspubsub.WorkspaceEvent, err error) { + if err != nil { + return + } + if e.Kind == wspubsub.WorkspaceEventKindStateChange && e.WorkspaceID == workspaceTable.ID { + close(publishedWorkspace) + } + })) + require.NoError(t, err) + defer closeWorkspaceSubscribe() + + // The actual test + _, err = srv.CompleteJob(ctx, &proto.CompletedJob{ + JobId: job.ID.String(), + Type: &proto.CompletedJob_WorkspaceBuild_{ + WorkspaceBuild: &proto.CompletedJob_WorkspaceBuild{ + State: []byte{}, + Resources: []*sdkproto.Resource{{ + Name: "test-workspace-resource", + Type: "aws_instance", + }}, + Timings: []*sdkproto.Timing{{ + Stage: "test", + Source: "test-source", + Resource: "test-resource", + Action: "test-action", + Start: timestamppb.Now(), + End: timestamppb.Now(), + }}, + }, + }, + }) + require.NoError(t, err) + + // Wait for workspace notification + select { + case <-publishedWorkspace: + // Success + case <-time.After(testutil.WaitShort): + t.Fatal("Workspace event not published") + } + + // Verify job was marked as completed + completedJob, err := db.GetProvisionerJobByID(ctx, job.ID) + require.NoError(t, err) + require.True(t, completedJob.CompletedAt.Valid, "Job should be marked as completed") + + // Verify resources were created + resources, err := db.GetWorkspaceResourcesByJobID(ctx, job.ID) + require.NoError(t, err) + require.Len(t, resources, 1, "Expected one resource to be created") + require.Equal(t, "test-workspace-resource", resources[0].Name) + + // Verify timings were recorded + timings, err := db.GetProvisionerJobTimingsByJobID(ctx, job.ID) + require.NoError(t, err) + require.Len(t, timings, 1, "Expected one timing entry to be created") + require.Equal(t, "test", string(timings[0].Stage), "Timing stage should match what was sent") + }) + }) + t.Run("TemplateImport_MissingGitAuth", func(t *testing.T) { t.Parallel() srv, db, _, pd := setup(t, false, &overrides{})