From ac192e79fe2ef837cb57dbf1f64244ca6f803451 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 29 Aug 2023 11:18:42 +0000 Subject: [PATCH 1/4] fix: fix null pointer on external provisioner daemons with daily_cost Signed-off-by: Spike Curtis --- coderd/coderd.go | 40 +- .../provisionerdserver/provisionerdserver.go | 313 ++++++----- .../provisionerdserver_test.go | 497 +++++++++++------- enterprise/coderd/provisionerdaemons.go | 40 +- enterprise/coderd/provisionerdaemons_test.go | 108 ++++ 5 files changed, 649 insertions(+), 349 deletions(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index 0338a020eae36..27d566753909a 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1099,25 +1099,27 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti mux := drpcmux.New() - err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{ - AccessURL: api.AccessURL, - ID: daemon.ID, - OIDCConfig: api.OIDCConfig, - Database: api.Database, - Pubsub: api.Pubsub, - Provisioners: daemon.Provisioners, - GitAuthConfigs: api.GitAuthConfigs, - Telemetry: api.Telemetry, - Tracer: tracer, - Tags: tags, - QuotaCommitter: &api.QuotaCommitter, - Auditor: &api.Auditor, - TemplateScheduleStore: api.TemplateScheduleStore, - UserQuietHoursScheduleStore: api.UserQuietHoursScheduleStore, - AcquireJobDebounce: debounce, - Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)), - DeploymentValues: api.DeploymentValues, - }) + err = proto.DRPCRegisterProvisionerDaemon(mux, provisionerdserver.NewServer( + api.AccessURL, + daemon.ID, + api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)), + daemon.Provisioners, + tags, + api.Database, + api.Pubsub, + api.Telemetry, + tracer, + &api.QuotaCommitter, + &api.Auditor, + api.TemplateScheduleStore, + api.UserQuietHoursScheduleStore, + api.DeploymentValues, + debounce, + provisionerdserver.Options{ + OIDCConfig: api.OIDCConfig, + GitAuthConfigs: api.GitAuthConfigs, + }, + )) if err != nil { return nil, err } diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index 695892c86c9cc..f07513d83ad38 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -48,7 +48,14 @@ var ( lastAcquireMutex sync.RWMutex ) -type Server struct { +type Options struct { + OIDCConfig httpmw.OAuth2Config + GitAuthConfigs []*gitauth.Config + // TimeNowFn is only used in tests + TimeNowFn func() time.Time +} + +type server struct { AccessURL *url.URL ID uuid.UUID Logger slog.Logger @@ -71,17 +78,73 @@ type Server struct { TimeNowFn func() time.Time } +func NewServer( + accessURL *url.URL, + id uuid.UUID, + logger slog.Logger, + provisioners []database.ProvisionerType, + tags json.RawMessage, + db database.Store, + ps pubsub.Pubsub, + tel telemetry.Reporter, + tracer trace.Tracer, + quotaCommitter *atomic.Pointer[proto.QuotaCommitter], + auditor *atomic.Pointer[audit.Auditor], + templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore], + userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore], + deploymentValues *codersdk.DeploymentValues, + acquireJobDebounce time.Duration, + options Options, +) proto.DRPCProvisionerDaemonServer { + // Panic early if pointers are nil + if quotaCommitter == nil { + panic("quotaCommitter is nil") + } + if auditor == nil { + panic("auditor is nil") + } + if templateScheduleStore == nil { + panic("templateScheduleStore is nil") + } + if userQuietHoursScheduleStore == nil { + panic("userQuietHoursScheduleStore is nil") + } + if deploymentValues == nil { + panic("deploymentValues is nil") + } + return &server{ + AccessURL: accessURL, + ID: id, + Logger: logger, + Provisioners: provisioners, + GitAuthConfigs: options.GitAuthConfigs, + Tags: tags, + Database: db, + Pubsub: ps, + Telemetry: tel, + Tracer: tracer, + QuotaCommitter: quotaCommitter, + Auditor: auditor, + TemplateScheduleStore: templateScheduleStore, + UserQuietHoursScheduleStore: userQuietHoursScheduleStore, + DeploymentValues: deploymentValues, + AcquireJobDebounce: acquireJobDebounce, + OIDCConfig: options.OIDCConfig, + TimeNowFn: options.TimeNowFn, + } +} + // timeNow should be used when trying to get the current time for math // calculations regarding workspace start and stop time. -func (server *Server) timeNow() time.Time { - if server.TimeNowFn != nil { - return database.Time(server.TimeNowFn()) +func (s *server) timeNow() time.Time { + if s.TimeNowFn != nil { + return database.Time(s.TimeNowFn()) } return database.Now() } // AcquireJob queries the database to lock a job. -func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { +func (s *server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { //nolint:gocritic // Provisionerd has specific authz rules. ctx = dbauthz.AsProvisionerd(ctx) // This prevents loads of provisioner daemons from consistently @@ -90,23 +153,23 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac // The debounce only occurs when no job is returned, so if loads of // jobs are added at once, they will start after at most this duration. lastAcquireMutex.RLock() - if !lastAcquire.IsZero() && time.Since(lastAcquire) < server.AcquireJobDebounce { + if !lastAcquire.IsZero() && time.Since(lastAcquire) < s.AcquireJobDebounce { lastAcquireMutex.RUnlock() return &proto.AcquiredJob{}, nil } lastAcquireMutex.RUnlock() // This marks the job as locked in the database. - job, err := server.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + job, err := s.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ StartedAt: sql.NullTime{ Time: database.Now(), Valid: true, }, WorkerID: uuid.NullUUID{ - UUID: server.ID, + UUID: s.ID, Valid: true, }, - Types: server.Provisioners, - Tags: server.Tags, + Types: s.Provisioners, + Tags: s.Tags, }) if errors.Is(err, sql.ErrNoRows) { // The provisioner daemon assumes no jobs are available if @@ -119,11 +182,11 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac if err != nil { return nil, xerrors.Errorf("acquire job: %w", err) } - server.Logger.Debug(ctx, "locked job from database", slog.F("job_id", job.ID)) + s.Logger.Debug(ctx, "locked job from database", slog.F("job_id", job.ID)) // Marks the acquired job as failed with the error message provided. failJob := func(errorMessage string) error { - err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ ID: job.ID, CompletedAt: sql.NullTime{ Time: database.Now(), @@ -141,7 +204,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac return xerrors.Errorf("request job was invalidated: %s", errorMessage) } - user, err := server.Database.GetUserByID(ctx, job.InitiatorID) + user, err := s.Database.GetUserByID(ctx, job.InitiatorID) if err != nil { return nil, failJob(fmt.Sprintf("get user: %s", err)) } @@ -169,38 +232,38 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac if err != nil { return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err)) } - workspaceBuild, err := server.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID) + workspaceBuild, err := s.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID) if err != nil { return nil, failJob(fmt.Sprintf("get workspace build: %s", err)) } - workspace, err := server.Database.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID) + workspace, err := s.Database.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID) if err != nil { return nil, failJob(fmt.Sprintf("get workspace: %s", err)) } - templateVersion, err := server.Database.GetTemplateVersionByID(ctx, workspaceBuild.TemplateVersionID) + templateVersion, err := s.Database.GetTemplateVersionByID(ctx, workspaceBuild.TemplateVersionID) if err != nil { return nil, failJob(fmt.Sprintf("get template version: %s", err)) } - templateVariables, err := server.Database.GetTemplateVersionVariables(ctx, templateVersion.ID) + templateVariables, err := s.Database.GetTemplateVersionVariables(ctx, templateVersion.ID) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { return nil, failJob(fmt.Sprintf("get template version variables: %s", err)) } - template, err := server.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID) + template, err := s.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID) if err != nil { return nil, failJob(fmt.Sprintf("get template: %s", err)) } - owner, err := server.Database.GetUserByID(ctx, workspace.OwnerID) + owner, err := s.Database.GetUserByID(ctx, workspace.OwnerID) if err != nil { return nil, failJob(fmt.Sprintf("get owner: %s", err)) } - err = server.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspace.ID), []byte{}) + err = s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspace.ID), []byte{}) if err != nil { return nil, failJob(fmt.Sprintf("publish workspace update: %s", err)) } var workspaceOwnerOIDCAccessToken string - if server.OIDCConfig != nil { - workspaceOwnerOIDCAccessToken, err = obtainOIDCAccessToken(ctx, server.Database, server.OIDCConfig, owner.ID) + if s.OIDCConfig != nil { + workspaceOwnerOIDCAccessToken, err = obtainOIDCAccessToken(ctx, s.Database, s.OIDCConfig, owner.ID) if err != nil { return nil, failJob(fmt.Sprintf("obtain OIDC access token: %s", err)) } @@ -209,12 +272,12 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac var sessionToken string switch workspaceBuild.Transition { case database.WorkspaceTransitionStart: - sessionToken, err = server.regenerateSessionToken(ctx, owner, workspace) + sessionToken, err = s.regenerateSessionToken(ctx, owner, workspace) if err != nil { return nil, failJob(fmt.Sprintf("regenerate session token: %s", err)) } case database.WorkspaceTransitionStop, database.WorkspaceTransitionDelete: - err = deleteSessionToken(ctx, server.Database, workspace) + err = deleteSessionToken(ctx, s.Database, workspace) if err != nil { return nil, failJob(fmt.Sprintf("delete session token: %s", err)) } @@ -225,14 +288,14 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac return nil, failJob(fmt.Sprintf("convert workspace transition: %s", err)) } - workspaceBuildParameters, err := server.Database.GetWorkspaceBuildParameters(ctx, workspaceBuild.ID) + workspaceBuildParameters, err := s.Database.GetWorkspaceBuildParameters(ctx, workspaceBuild.ID) if err != nil { return nil, failJob(fmt.Sprintf("get workspace build parameters: %s", err)) } gitAuthProviders := []*sdkproto.GitAuthProvider{} for _, p := range templateVersion.GitAuthProviders { - link, err := server.Database.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{ + link, err := s.Database.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{ ProviderID: p, UserID: owner.ID, }) @@ -243,7 +306,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac return nil, failJob(fmt.Sprintf("acquire git auth link: %s", err)) } var config *gitauth.Config - for _, c := range server.GitAuthConfigs { + for _, c := range s.GitAuthConfigs { if c.ID != p { continue } @@ -252,14 +315,14 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac } // We weren't able to find a matching config for the ID! if config == nil { - server.Logger.Warn(ctx, "workspace build job is missing git provider", + s.Logger.Warn(ctx, "workspace build job is missing git provider", slog.F("git_provider_id", p), slog.F("template_version_id", templateVersion.ID), slog.F("workspace_id", workspaceBuild.WorkspaceID)) continue } - link, valid, err := config.RefreshToken(ctx, server.Database, link) + link, valid, err := config.RefreshToken(ctx, s.Database, link) if err != nil { return nil, failJob(fmt.Sprintf("refresh git auth link %q: %s", p, err)) } @@ -281,7 +344,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac VariableValues: asVariableValues(templateVariables), GitAuthProviders: gitAuthProviders, Metadata: &sdkproto.Metadata{ - CoderUrl: server.AccessURL.String(), + CoderUrl: s.AccessURL.String(), WorkspaceTransition: transition, WorkspaceName: workspace.Name, WorkspaceOwner: owner.Username, @@ -303,11 +366,11 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err)) } - templateVersion, err := server.Database.GetTemplateVersionByID(ctx, input.TemplateVersionID) + templateVersion, err := s.Database.GetTemplateVersionByID(ctx, input.TemplateVersionID) if err != nil { return nil, failJob(fmt.Sprintf("get template version: %s", err)) } - templateVariables, err := server.Database.GetTemplateVersionVariables(ctx, templateVersion.ID) + templateVariables, err := s.Database.GetTemplateVersionVariables(ctx, templateVersion.ID) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { return nil, failJob(fmt.Sprintf("get template version variables: %s", err)) } @@ -317,7 +380,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac RichParameterValues: convertRichParameterValues(input.RichParameterValues), VariableValues: asVariableValues(templateVariables), Metadata: &sdkproto.Metadata{ - CoderUrl: server.AccessURL.String(), + CoderUrl: s.AccessURL.String(), WorkspaceName: input.WorkspaceName, }, }, @@ -329,7 +392,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err)) } - userVariableValues, err := server.includeLastVariableValues(ctx, input.TemplateVersionID, input.UserVariableValues) + userVariableValues, err := s.includeLastVariableValues(ctx, input.TemplateVersionID, input.UserVariableValues) if err != nil { return nil, failJob(err.Error()) } @@ -338,14 +401,14 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac TemplateImport: &proto.AcquiredJob_TemplateImport{ UserVariableValues: convertVariableValues(userVariableValues), Metadata: &sdkproto.Metadata{ - CoderUrl: server.AccessURL.String(), + CoderUrl: s.AccessURL.String(), }, }, } } switch job.StorageMethod { case database.ProvisionerStorageMethodFile: - file, err := server.Database.GetFileByID(ctx, job.FileID) + file, err := s.Database.GetFileByID(ctx, job.FileID) if err != nil { return nil, failJob(fmt.Sprintf("get file by hash: %s", err)) } @@ -360,7 +423,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac return protoJob, err } -func (server *Server) includeLastVariableValues(ctx context.Context, templateVersionID uuid.UUID, userVariableValues []codersdk.VariableValue) ([]codersdk.VariableValue, error) { +func (s *server) includeLastVariableValues(ctx context.Context, templateVersionID uuid.UUID, userVariableValues []codersdk.VariableValue) ([]codersdk.VariableValue, error) { var values []codersdk.VariableValue values = append(values, userVariableValues...) @@ -368,7 +431,7 @@ func (server *Server) includeLastVariableValues(ctx context.Context, templateVer return values, nil } - templateVersion, err := server.Database.GetTemplateVersionByID(ctx, templateVersionID) + templateVersion, err := s.Database.GetTemplateVersionByID(ctx, templateVersionID) if err != nil { return nil, xerrors.Errorf("get template version: %w", err) } @@ -377,7 +440,7 @@ func (server *Server) includeLastVariableValues(ctx context.Context, templateVer return values, nil } - template, err := server.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID) + template, err := s.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID) if err != nil { return nil, xerrors.Errorf("get template: %w", err) } @@ -386,7 +449,7 @@ func (server *Server) includeLastVariableValues(ctx context.Context, templateVer return values, nil } - templateVariables, err := server.Database.GetTemplateVersionVariables(ctx, template.ActiveVersionID) + templateVariables, err := s.Database.GetTemplateVersionVariables(ctx, template.ActiveVersionID) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { return nil, xerrors.Errorf("get template version variables: %w", err) } @@ -412,8 +475,8 @@ func (server *Server) includeLastVariableValues(ctx context.Context, templateVer return values, nil } -func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) { - ctx, span := server.startTrace(ctx, tracing.FuncName()) +func (s *server) CommitQuota(ctx context.Context, request *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) { + ctx, span := s.startTrace(ctx, tracing.FuncName()) defer span.End() //nolint:gocritic // Provisionerd has specific authz rules. @@ -423,7 +486,7 @@ func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuot return nil, xerrors.Errorf("parse job id: %w", err) } - job, err := server.Database.GetProvisionerJobByID(ctx, jobID) + job, err := s.Database.GetProvisionerJobByID(ctx, jobID) if err != nil { return nil, xerrors.Errorf("get job: %w", err) } @@ -431,11 +494,11 @@ func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuot return nil, xerrors.New("job isn't running yet") } - if job.WorkerID.UUID.String() != server.ID.String() { + if job.WorkerID.UUID.String() != s.ID.String() { return nil, xerrors.New("you don't own this job") } - q := server.QuotaCommitter.Load() + q := s.QuotaCommitter.Load() if q == nil { // We're probably in community edition or a test. return &proto.CommitQuotaResponse{ @@ -446,8 +509,8 @@ func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuot return (*q).CommitQuota(ctx, request) } -func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { - ctx, span := server.startTrace(ctx, tracing.FuncName()) +func (s *server) UpdateJob(ctx context.Context, request *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { + ctx, span := s.startTrace(ctx, tracing.FuncName()) defer span.End() //nolint:gocritic // Provisionerd has specific authz rules. @@ -456,18 +519,18 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq if err != nil { return nil, xerrors.Errorf("parse job id: %w", err) } - server.Logger.Debug(ctx, "stage UpdateJob starting", slog.F("job_id", parsedID)) - job, err := server.Database.GetProvisionerJobByID(ctx, parsedID) + s.Logger.Debug(ctx, "stage UpdateJob starting", slog.F("job_id", parsedID)) + job, err := s.Database.GetProvisionerJobByID(ctx, parsedID) if err != nil { return nil, xerrors.Errorf("get job: %w", err) } if !job.WorkerID.Valid { return nil, xerrors.New("job isn't running yet") } - if job.WorkerID.UUID.String() != server.ID.String() { + if job.WorkerID.UUID.String() != s.ID.String() { return nil, xerrors.New("you don't own this job") } - err = server.Database.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{ + err = s.Database.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{ ID: parsedID, UpdatedAt: database.Now(), }) @@ -493,37 +556,37 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq insertParams.Stage = append(insertParams.Stage, log.Stage) insertParams.Source = append(insertParams.Source, logSource) insertParams.Output = append(insertParams.Output, log.Output) - server.Logger.Debug(ctx, "job log", + s.Logger.Debug(ctx, "job log", slog.F("job_id", parsedID), slog.F("stage", log.Stage), slog.F("output", log.Output)) } - logs, err := server.Database.InsertProvisionerJobLogs(ctx, insertParams) + logs, err := s.Database.InsertProvisionerJobLogs(ctx, insertParams) if err != nil { - server.Logger.Error(ctx, "failed to insert job logs", slog.F("job_id", parsedID), slog.Error(err)) + s.Logger.Error(ctx, "failed to insert job logs", slog.F("job_id", parsedID), slog.Error(err)) return nil, xerrors.Errorf("insert job logs: %w", err) } // Publish by the lowest log ID inserted so the log stream will fetch // everything from that point. lowestID := logs[0].ID - server.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID)) + s.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID)) data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{ CreatedAfter: lowestID - 1, }) if err != nil { return nil, xerrors.Errorf("marshal: %w", err) } - err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(parsedID), data) + err = s.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(parsedID), data) if err != nil { - server.Logger.Error(ctx, "failed to publish job logs", slog.F("job_id", parsedID), slog.Error(err)) + s.Logger.Error(ctx, "failed to publish job logs", slog.F("job_id", parsedID), slog.Error(err)) return nil, xerrors.Errorf("publish job logs: %w", err) } - server.Logger.Debug(ctx, "published job logs", slog.F("job_id", parsedID)) + s.Logger.Debug(ctx, "published job logs", slog.F("job_id", parsedID)) } if len(request.Readme) > 0 { - err := server.Database.UpdateTemplateVersionDescriptionByJobID(ctx, database.UpdateTemplateVersionDescriptionByJobIDParams{ + err := s.Database.UpdateTemplateVersionDescriptionByJobID(ctx, database.UpdateTemplateVersionDescriptionByJobIDParams{ JobID: job.ID, Readme: string(request.Readme), UpdatedAt: database.Now(), @@ -534,16 +597,16 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq } if len(request.TemplateVariables) > 0 { - templateVersion, err := server.Database.GetTemplateVersionByJobID(ctx, job.ID) + templateVersion, err := s.Database.GetTemplateVersionByJobID(ctx, job.ID) if err != nil { - server.Logger.Error(ctx, "failed to get the template version", slog.F("job_id", parsedID), slog.Error(err)) + s.Logger.Error(ctx, "failed to get the template version", slog.F("job_id", parsedID), slog.Error(err)) return nil, xerrors.Errorf("get template version by job id: %w", err) } var variableValues []*sdkproto.VariableValue var variablesWithMissingValues []string for _, templateVariable := range request.TemplateVariables { - server.Logger.Debug(ctx, "insert template variable", slog.F("template_version_id", templateVersion.ID), slog.F("template_variable", redactTemplateVariable(templateVariable))) + s.Logger.Debug(ctx, "insert template variable", slog.F("template_version_id", templateVersion.ID), slog.F("template_variable", redactTemplateVariable(templateVariable))) value := templateVariable.DefaultValue for _, v := range request.UserVariableValues { @@ -563,7 +626,7 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq Sensitive: templateVariable.Sensitive, }) - _, err = server.Database.InsertTemplateVersionVariable(ctx, database.InsertTemplateVersionVariableParams{ + _, err = s.Database.InsertTemplateVersionVariable(ctx, database.InsertTemplateVersionVariableParams{ TemplateVersionID: templateVersion.ID, Name: templateVariable.Name, Description: templateVariable.Description, @@ -593,8 +656,8 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq }, nil } -func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*proto.Empty, error) { - ctx, span := server.startTrace(ctx, tracing.FuncName()) +func (s *server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*proto.Empty, error) { + ctx, span := s.startTrace(ctx, tracing.FuncName()) defer span.End() //nolint:gocritic // Provisionerd has specific authz rules. @@ -603,12 +666,12 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p if err != nil { return nil, xerrors.Errorf("parse job id: %w", err) } - server.Logger.Debug(ctx, "stage FailJob starting", slog.F("job_id", jobID)) - job, err := server.Database.GetProvisionerJobByID(ctx, jobID) + s.Logger.Debug(ctx, "stage FailJob starting", slog.F("job_id", jobID)) + job, err := s.Database.GetProvisionerJobByID(ctx, jobID) if err != nil { return nil, xerrors.Errorf("get provisioner job: %w", err) } - if job.WorkerID.UUID.String() != server.ID.String() { + if job.WorkerID.UUID.String() != s.ID.String() { return nil, xerrors.New("you don't own this job") } if job.CompletedAt.Valid { @@ -627,7 +690,7 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p Valid: failJob.ErrorCode != "", } - err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ ID: jobID, CompletedAt: job.CompletedAt, UpdatedAt: database.Now(), @@ -637,7 +700,7 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p if err != nil { return nil, xerrors.Errorf("update provisioner job: %w", err) } - server.Telemetry.Report(&telemetry.Snapshot{ + s.Telemetry.Report(&telemetry.Snapshot{ ProvisionerJobs: []telemetry.ProvisionerJob{telemetry.ConvertProvisionerJob(job)}, }) @@ -650,7 +713,7 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p } var build database.WorkspaceBuild - err = server.Database.InTx(func(db database.Store) error { + err = s.Database.InTx(func(db database.Store) error { build, err = db.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID) if err != nil { return xerrors.Errorf("get workspace build: %w", err) @@ -675,7 +738,7 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p return nil, err } - err = server.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), []byte{}) + err = s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), []byte{}) if err != nil { return nil, xerrors.Errorf("update workspace: %w", err) } @@ -684,18 +747,18 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p // if failed job is a workspace build, audit the outcome if job.Type == database.ProvisionerJobTypeWorkspaceBuild { - auditor := server.Auditor.Load() - build, err := server.Database.GetWorkspaceBuildByJobID(ctx, job.ID) + auditor := s.Auditor.Load() + build, err := s.Database.GetWorkspaceBuildByJobID(ctx, job.ID) if err != nil { - server.Logger.Error(ctx, "audit log - get build", slog.Error(err)) + s.Logger.Error(ctx, "audit log - get build", slog.Error(err)) } else { auditAction := auditActionFromTransition(build.Transition) - workspace, err := server.Database.GetWorkspaceByID(ctx, build.WorkspaceID) + workspace, err := s.Database.GetWorkspaceByID(ctx, build.WorkspaceID) if err != nil { - server.Logger.Error(ctx, "audit log - get workspace", slog.Error(err)) + s.Logger.Error(ctx, "audit log - get workspace", slog.Error(err)) } else { previousBuildNumber := build.BuildNumber - 1 - previousBuild, prevBuildErr := server.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{ + previousBuild, prevBuildErr := s.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{ WorkspaceID: workspace.ID, BuildNumber: previousBuildNumber, }) @@ -713,12 +776,12 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p wriBytes, err := json.Marshal(buildResourceInfo) if err != nil { - server.Logger.Error(ctx, "marshal workspace resource info for failed job", slog.Error(err)) + s.Logger.Error(ctx, "marshal workspace resource info for failed job", slog.Error(err)) } audit.BuildAudit(ctx, &audit.BuildAuditParams[database.WorkspaceBuild]{ Audit: *auditor, - Log: server.Logger, + Log: s.Logger, UserID: job.InitiatorID, JobID: job.ID, Action: auditAction, @@ -735,17 +798,17 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p if err != nil { return nil, xerrors.Errorf("marshal job log: %w", err) } - err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data) + err = s.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data) if err != nil { - server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err)) + s.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err)) return nil, xerrors.Errorf("publish end of job logs: %w", err) } return &proto.Empty{}, nil } // CompleteJob is triggered by a provision daemon to mark a provisioner job as completed. -func (server *Server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) { - ctx, span := server.startTrace(ctx, tracing.FuncName()) +func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) { + ctx, span := s.startTrace(ctx, tracing.FuncName()) defer span.End() //nolint:gocritic // Provisionerd has specific authz rules. @@ -754,18 +817,18 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete if err != nil { return nil, xerrors.Errorf("parse job id: %w", err) } - server.Logger.Debug(ctx, "stage CompleteJob starting", slog.F("job_id", jobID)) - job, err := server.Database.GetProvisionerJobByID(ctx, jobID) + s.Logger.Debug(ctx, "stage CompleteJob starting", slog.F("job_id", jobID)) + job, err := s.Database.GetProvisionerJobByID(ctx, jobID) if err != nil { return nil, xerrors.Errorf("get job by id: %w", err) } - if job.WorkerID.UUID.String() != server.ID.String() { + if job.WorkerID.UUID.String() != s.ID.String() { return nil, xerrors.Errorf("you don't own this job") } telemetrySnapshot := &telemetry.Snapshot{} // Items are added to this snapshot as they complete! - defer server.Telemetry.Report(telemetrySnapshot) + defer s.Telemetry.Report(telemetrySnapshot) switch jobType := completed.Type.(type) { case *proto.CompletedJob_TemplateImport_: @@ -780,13 +843,13 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete database.WorkspaceTransitionStop: jobType.TemplateImport.StopResources, } { for _, resource := range resources { - server.Logger.Info(ctx, "inserting template import job resource", + s.Logger.Info(ctx, "inserting template import job resource", slog.F("job_id", job.ID.String()), slog.F("resource_name", resource.Name), slog.F("resource_type", resource.Type), slog.F("transition", transition)) - err = InsertWorkspaceResource(ctx, server.Database, jobID, transition, resource, telemetrySnapshot) + err = InsertWorkspaceResource(ctx, s.Database, jobID, transition, resource, telemetrySnapshot) if err != nil { return nil, xerrors.Errorf("insert resource: %w", err) } @@ -794,7 +857,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete } for _, richParameter := range jobType.TemplateImport.RichParameters { - server.Logger.Info(ctx, "inserting template import job parameter", + s.Logger.Info(ctx, "inserting template import job parameter", slog.F("job_id", job.ID.String()), slog.F("parameter_name", richParameter.Name), slog.F("type", richParameter.Type), @@ -819,7 +882,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete } } - _, err = server.Database.InsertTemplateVersionParameter(ctx, database.InsertTemplateVersionParameterParams{ + _, err = s.Database.InsertTemplateVersionParameter(ctx, database.InsertTemplateVersionParameterParams{ TemplateVersionID: input.TemplateVersionID, Name: richParameter.Name, DisplayName: richParameter.DisplayName, @@ -847,7 +910,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete for _, gitAuthProvider := range jobType.TemplateImport.GitAuthProviders { contains := false - for _, configuredProvider := range server.GitAuthConfigs { + for _, configuredProvider := range s.GitAuthConfigs { if configuredProvider.ID == gitAuthProvider { contains = true break @@ -862,7 +925,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete } } - err = server.Database.UpdateTemplateVersionGitAuthProvidersByJobID(ctx, database.UpdateTemplateVersionGitAuthProvidersByJobIDParams{ + err = s.Database.UpdateTemplateVersionGitAuthProvidersByJobID(ctx, database.UpdateTemplateVersionGitAuthProvidersByJobIDParams{ JobID: jobID, GitAuthProviders: jobType.TemplateImport.GitAuthProviders, UpdatedAt: database.Now(), @@ -871,7 +934,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete return nil, xerrors.Errorf("update template version git auth providers: %w", err) } - err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ ID: jobID, UpdatedAt: database.Now(), CompletedAt: sql.NullTime{ @@ -883,7 +946,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete if err != nil { return nil, xerrors.Errorf("update provisioner job: %w", err) } - server.Logger.Debug(ctx, "marked import job as completed", slog.F("job_id", jobID)) + s.Logger.Debug(ctx, "marked import job as completed", slog.F("job_id", jobID)) if err != nil { return nil, xerrors.Errorf("complete job: %w", err) } @@ -894,7 +957,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete return nil, xerrors.Errorf("unmarshal job data: %w", err) } - workspaceBuild, err := server.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID) + workspaceBuild, err := s.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID) if err != nil { return nil, xerrors.Errorf("get workspace build: %w", err) } @@ -902,14 +965,14 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete var workspace database.Workspace var getWorkspaceError error - err = server.Database.InTx(func(db database.Store) error { - // It's important we use server.timeNow() here because we want to be + err = s.Database.InTx(func(db database.Store) error { + // It's important we use s.timeNow() here because we want to be // able to customize the current time from within tests. - now := server.timeNow() + now := s.timeNow() workspace, getWorkspaceError = db.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID) if getWorkspaceError != nil { - server.Logger.Error(ctx, + s.Logger.Error(ctx, "fetch workspace for build", slog.F("workspace_build_id", workspaceBuild.ID), slog.F("workspace_id", workspaceBuild.WorkspaceID), @@ -919,8 +982,8 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete autoStop, err := schedule.CalculateAutostop(ctx, schedule.CalculateAutostopParams{ Database: db, - TemplateScheduleStore: *server.TemplateScheduleStore.Load(), - UserQuietHoursScheduleStore: *server.UserQuietHoursScheduleStore.Load(), + TemplateScheduleStore: *s.TemplateScheduleStore.Load(), + UserQuietHoursScheduleStore: *s.UserQuietHoursScheduleStore.Load(), Now: now, Workspace: workspace, }) @@ -976,7 +1039,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete var updates []<-chan time.Time for _, d := range timeouts { - server.Logger.Debug(ctx, "triggering workspace notification after agent timeout", + s.Logger.Debug(ctx, "triggering workspace notification after agent timeout", slog.F("workspace_build_id", workspaceBuild.ID), slog.F("timeout", d), ) @@ -988,11 +1051,11 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete for _, wait := range updates { // Wait for the next potential timeout to occur. Note that we // can't listen on the context here because we will hang around - // after this function has returned. The server also doesn't + // after this function has returned. The s also doesn't // have a shutdown signal we can listen to. <-wait - if err := server.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{}); err != nil { - server.Logger.Error(ctx, "workspace notification after agent timeout failed", + if err := s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{}); err != nil { + s.Logger.Error(ctx, "workspace notification after agent timeout failed", slog.F("workspace_build_id", workspaceBuild.ID), slog.Error(err), ) @@ -1022,11 +1085,11 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete // audit the outcome of the workspace build if getWorkspaceError == nil { - auditor := server.Auditor.Load() + auditor := s.Auditor.Load() auditAction := auditActionFromTransition(workspaceBuild.Transition) previousBuildNumber := workspaceBuild.BuildNumber - 1 - previousBuild, prevBuildErr := server.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{ + previousBuild, prevBuildErr := s.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{ WorkspaceID: workspace.ID, BuildNumber: previousBuildNumber, }) @@ -1044,12 +1107,12 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete wriBytes, err := json.Marshal(buildResourceInfo) if err != nil { - server.Logger.Error(ctx, "marshal resource info for successful job", slog.Error(err)) + s.Logger.Error(ctx, "marshal resource info for successful job", slog.Error(err)) } audit.BuildAudit(ctx, &audit.BuildAuditParams[database.WorkspaceBuild]{ Audit: *auditor, - Log: server.Logger, + Log: s.Logger, UserID: job.InitiatorID, JobID: job.ID, Action: auditAction, @@ -1060,24 +1123,24 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete }) } - err = server.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{}) + err = s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{}) if err != nil { return nil, xerrors.Errorf("update workspace: %w", err) } case *proto.CompletedJob_TemplateDryRun_: for _, resource := range jobType.TemplateDryRun.Resources { - server.Logger.Info(ctx, "inserting template dry-run job resource", + s.Logger.Info(ctx, "inserting template dry-run job resource", slog.F("job_id", job.ID.String()), slog.F("resource_name", resource.Name), slog.F("resource_type", resource.Type)) - err = InsertWorkspaceResource(ctx, server.Database, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot) + err = InsertWorkspaceResource(ctx, s.Database, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot) if err != nil { return nil, xerrors.Errorf("insert resource: %w", err) } } - err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ ID: jobID, UpdatedAt: database.Now(), CompletedAt: sql.NullTime{ @@ -1088,7 +1151,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete if err != nil { return nil, xerrors.Errorf("update provisioner job: %w", err) } - server.Logger.Debug(ctx, "marked template dry-run job as completed", slog.F("job_id", jobID)) + s.Logger.Debug(ctx, "marked template dry-run job as completed", slog.F("job_id", jobID)) if err != nil { return nil, xerrors.Errorf("complete job: %w", err) } @@ -1105,18 +1168,18 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete if err != nil { return nil, xerrors.Errorf("marshal job log: %w", err) } - err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data) + err = s.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data) if err != nil { - server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err)) + s.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err)) return nil, xerrors.Errorf("publish end of job logs: %w", err) } - server.Logger.Debug(ctx, "stage CompleteJob done", slog.F("job_id", jobID)) + s.Logger.Debug(ctx, "stage CompleteJob done", slog.F("job_id", jobID)) return &proto.Empty{}, nil } -func (server *Server) startTrace(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { - return server.Tracer.Start(ctx, name, append(opts, trace.WithAttributes( +func (s *server) startTrace(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { + return s.Tracer.Start(ctx, name, append(opts, trace.WithAttributes( semconv.ServiceNameKey.String("coderd.provisionerd"), ))...) } @@ -1316,19 +1379,19 @@ func workspaceSessionTokenName(workspace database.Workspace) string { return fmt.Sprintf("%s_%s_session_token", workspace.OwnerID, workspace.ID) } -func (server *Server) regenerateSessionToken(ctx context.Context, user database.User, workspace database.Workspace) (string, error) { +func (s *server) regenerateSessionToken(ctx context.Context, user database.User, workspace database.Workspace) (string, error) { newkey, sessionToken, err := apikey.Generate(apikey.CreateParams{ UserID: user.ID, LoginType: user.LoginType, - DeploymentValues: server.DeploymentValues, + DeploymentValues: s.DeploymentValues, TokenName: workspaceSessionTokenName(workspace), - LifetimeSeconds: int64(server.DeploymentValues.MaxTokenLifetime.Value().Seconds()), + LifetimeSeconds: int64(s.DeploymentValues.MaxTokenLifetime.Value().Seconds()), }) if err != nil { return "", xerrors.Errorf("generate API key: %w", err) } - err = server.Database.InTx(func(tx database.Store) error { + err = s.Database.InTx(func(tx database.Store) error { err := deleteSessionToken(ctx, tx, workspace) if err != nil { return xerrors.Errorf("delete session token: %w", err) diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index 5a317cd531530..eba7ec2cbe267 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -61,25 +61,28 @@ func TestAcquireJob(t *testing.T) { t.Parallel() db := dbfake.New() ps := pubsub.NewInMemory() - srv := &provisionerdserver.Server{ - ID: uuid.New(), - Logger: slogtest.Make(t, nil), - AccessURL: &url.URL{}, - Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho}, - Database: db, - Pubsub: ps, - Telemetry: telemetry.NewNoop(), - AcquireJobDebounce: time.Hour, - Auditor: mockAuditor(), - TemplateScheduleStore: testTemplateScheduleStore(), - UserQuietHoursScheduleStore: testUserQuietHoursScheduleStore(), - Tracer: trace.NewNoopTracerProvider().Tracer("noop"), - DeploymentValues: &codersdk.DeploymentValues{}, - } + srv := provisionerdserver.NewServer( + &url.URL{}, + uuid.New(), + slogtest.Make(t, nil), + []database.ProvisionerType{database.ProvisionerTypeEcho}, + nil, + db, + ps, + telemetry.NewNoop(), + trace.NewNoopTracerProvider().Tracer("noop"), + &atomic.Pointer[proto.QuotaCommitter]{}, + mockAuditor(), + testTemplateScheduleStore(), + testUserQuietHoursScheduleStore(), + &codersdk.DeploymentValues{}, + time.Hour, + provisionerdserver.Options{}, + ) job, err := srv.AcquireJob(context.Background(), nil) require.NoError(t, err) require.Equal(t, &proto.AcquiredJob{}, job) - _, err = srv.Database.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ + _, err = db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ ID: uuid.New(), InitiatorID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, @@ -93,15 +96,15 @@ func TestAcquireJob(t *testing.T) { }) t.Run("NoJobs", func(t *testing.T) { t.Parallel() - srv := setup(t, false) + srv, _, _ := setup(t, false, nil) job, err := srv.AcquireJob(context.Background(), nil) require.NoError(t, err) require.Equal(t, &proto.AcquiredJob{}, job) }) t.Run("InitiatorNotFound", func(t *testing.T) { t.Parallel() - srv := setup(t, false) - _, err := srv.Database.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ + srv, db, _ := setup(t, false, nil) + _, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ ID: uuid.New(), InitiatorID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, @@ -114,50 +117,52 @@ func TestAcquireJob(t *testing.T) { }) t.Run("WorkspaceBuildJob", func(t *testing.T) { t.Parallel() - srv := setup(t, false) - gitAuthProvider := "github" // Set the max session token lifetime so we can assert we // create an API key with an expiration within the bounds of the // deployment config. - srv.DeploymentValues.MaxTokenLifetime = clibase.Duration(time.Hour) - srv.GitAuthConfigs = []*gitauth.Config{{ - ID: gitAuthProvider, - OAuth2Config: &testutil.OAuth2Config{}, - }} + dv := &codersdk.DeploymentValues{MaxTokenLifetime: clibase.Duration(time.Hour)} + gitAuthProvider := "github" + srv, db, ps := setup(t, false, &overrides{ + deploymentValues: dv, + gitAuthConfigs: []*gitauth.Config{{ + ID: gitAuthProvider, + OAuth2Config: &testutil.OAuth2Config{}, + }}, + }) ctx := context.Background() - user := dbgen.User(t, srv.Database, database.User{}) - link := dbgen.UserLink(t, srv.Database, database.UserLink{ + user := dbgen.User(t, db, database.User{}) + link := dbgen.UserLink(t, db, database.UserLink{ LoginType: database.LoginTypeOIDC, UserID: user.ID, OAuthExpiry: database.Now().Add(time.Hour), OAuthAccessToken: "access-token", }) - dbgen.GitAuthLink(t, srv.Database, database.GitAuthLink{ + dbgen.GitAuthLink(t, db, database.GitAuthLink{ ProviderID: gitAuthProvider, UserID: user.ID, }) - template := dbgen.Template(t, srv.Database, database.Template{ + template := dbgen.Template(t, db, database.Template{ Name: "template", Provisioner: database.ProvisionerTypeEcho, }) - file := dbgen.File(t, srv.Database, database.File{CreatedBy: user.ID}) - versionFile := dbgen.File(t, srv.Database, database.File{CreatedBy: user.ID}) - version := dbgen.TemplateVersion(t, srv.Database, database.TemplateVersion{ + file := dbgen.File(t, db, database.File{CreatedBy: user.ID}) + versionFile := dbgen.File(t, db, database.File{CreatedBy: user.ID}) + version := dbgen.TemplateVersion(t, db, database.TemplateVersion{ TemplateID: uuid.NullUUID{ UUID: template.ID, Valid: true, }, JobID: uuid.New(), }) - err := srv.Database.UpdateTemplateVersionGitAuthProvidersByJobID(ctx, database.UpdateTemplateVersionGitAuthProvidersByJobIDParams{ + err := db.UpdateTemplateVersionGitAuthProvidersByJobID(ctx, database.UpdateTemplateVersionGitAuthProvidersByJobIDParams{ JobID: version.JobID, GitAuthProviders: []string{gitAuthProvider}, UpdatedAt: database.Now(), }) require.NoError(t, err) // Import version job - _ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ + _ = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ ID: version.JobID, InitiatorID: user.ID, FileID: versionFile.ID, @@ -171,14 +176,14 @@ func TestAcquireJob(t *testing.T) { }, })), }) - _ = dbgen.TemplateVersionVariable(t, srv.Database, database.TemplateVersionVariable{ + _ = dbgen.TemplateVersionVariable(t, db, database.TemplateVersionVariable{ TemplateVersionID: version.ID, Name: "first", Value: "first_value", DefaultValue: "default_value", Sensitive: true, }) - _ = dbgen.TemplateVersionVariable(t, srv.Database, database.TemplateVersionVariable{ + _ = dbgen.TemplateVersionVariable(t, db, database.TemplateVersionVariable{ TemplateVersionID: version.ID, Name: "second", Value: "second_value", @@ -186,11 +191,11 @@ func TestAcquireJob(t *testing.T) { Required: true, Sensitive: false, }) - workspace := dbgen.Workspace(t, srv.Database, database.Workspace{ + workspace := dbgen.Workspace(t, db, database.Workspace{ TemplateID: template.ID, OwnerID: user.ID, }) - build := dbgen.WorkspaceBuild(t, srv.Database, database.WorkspaceBuild{ + build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ WorkspaceID: workspace.ID, BuildNumber: 1, JobID: uuid.New(), @@ -198,7 +203,7 @@ func TestAcquireJob(t *testing.T) { Transition: database.WorkspaceTransitionStart, Reason: database.BuildReasonInitiator, }) - _ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ + _ = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ ID: build.ID, InitiatorID: user.ID, Provisioner: database.ProvisionerTypeEcho, @@ -212,7 +217,7 @@ func TestAcquireJob(t *testing.T) { startPublished := make(chan struct{}) var closed bool - closeStartSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) { + closeStartSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) { if !closed { close(startPublished) closed = true @@ -243,10 +248,10 @@ func TestAcquireJob(t *testing.T) { require.NotEmpty(t, sessionToken) toks := strings.Split(sessionToken, "-") require.Len(t, toks, 2, "invalid api key") - key, err := srv.Database.GetAPIKeyByID(ctx, toks[0]) + key, err := db.GetAPIKeyByID(ctx, toks[0]) require.NoError(t, err) - require.Equal(t, int64(srv.DeploymentValues.MaxTokenLifetime.Value().Seconds()), key.LifetimeSeconds) - require.WithinDuration(t, time.Now().Add(srv.DeploymentValues.MaxTokenLifetime.Value()), key.ExpiresAt, time.Minute) + require.Equal(t, int64(dv.MaxTokenLifetime.Value().Seconds()), key.LifetimeSeconds) + require.WithinDuration(t, time.Now().Add(dv.MaxTokenLifetime.Value()), key.ExpiresAt, time.Minute) want, err := json.Marshal(&proto.AcquiredJob_WorkspaceBuild_{ WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ @@ -268,7 +273,7 @@ func TestAcquireJob(t *testing.T) { AccessToken: "access_token", }}, Metadata: &sdkproto.Metadata{ - CoderUrl: srv.AccessURL.String(), + CoderUrl: (&url.URL{}).String(), WorkspaceTransition: sdkproto.WorkspaceTransition_START, WorkspaceName: workspace.Name, WorkspaceOwner: user.Username, @@ -288,7 +293,7 @@ func TestAcquireJob(t *testing.T) { // Assert that we delete the session token whenever // a stop is issued. - stopbuild := dbgen.WorkspaceBuild(t, srv.Database, database.WorkspaceBuild{ + stopbuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ WorkspaceID: workspace.ID, BuildNumber: 2, JobID: uuid.New(), @@ -296,7 +301,7 @@ func TestAcquireJob(t *testing.T) { Transition: database.WorkspaceTransitionStop, Reason: database.BuildReasonInitiator, }) - _ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ + _ = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ ID: stopbuild.ID, InitiatorID: user.ID, Provisioner: database.ProvisionerTypeEcho, @@ -309,7 +314,7 @@ func TestAcquireJob(t *testing.T) { }) stopPublished := make(chan struct{}) - closeStopSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) { + closeStopSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) { close(stopPublished) }) require.NoError(t, err) @@ -327,19 +332,19 @@ func TestAcquireJob(t *testing.T) { // Validate that a session token is deleted during a stop job. sessionToken = job.Type.(*proto.AcquiredJob_WorkspaceBuild_).WorkspaceBuild.Metadata.WorkspaceOwnerSessionToken require.Empty(t, sessionToken) - _, err = srv.Database.GetAPIKeyByID(ctx, key.ID) + _, err = db.GetAPIKeyByID(ctx, key.ID) require.ErrorIs(t, err, sql.ErrNoRows) }) t.Run("TemplateVersionDryRun", func(t *testing.T) { t.Parallel() - srv := setup(t, false) + srv, db, _ := setup(t, false, nil) ctx := context.Background() - user := dbgen.User(t, srv.Database, database.User{}) - version := dbgen.TemplateVersion(t, srv.Database, database.TemplateVersion{}) - file := dbgen.File(t, srv.Database, database.File{CreatedBy: user.ID}) - _ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ + user := dbgen.User(t, db, database.User{}) + version := dbgen.TemplateVersion(t, db, database.TemplateVersion{}) + file := dbgen.File(t, db, database.File{CreatedBy: user.ID}) + _ = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ InitiatorID: user.ID, Provisioner: database.ProvisionerTypeEcho, StorageMethod: database.ProvisionerStorageMethodFile, @@ -360,7 +365,7 @@ func TestAcquireJob(t *testing.T) { want, err := json.Marshal(&proto.AcquiredJob_TemplateDryRun_{ TemplateDryRun: &proto.AcquiredJob_TemplateDryRun{ Metadata: &sdkproto.Metadata{ - CoderUrl: srv.AccessURL.String(), + CoderUrl: (&url.URL{}).String(), WorkspaceName: "testing", }, }, @@ -370,12 +375,12 @@ func TestAcquireJob(t *testing.T) { }) t.Run("TemplateVersionImport", func(t *testing.T) { t.Parallel() - srv := setup(t, false) + srv, db, _ := setup(t, false, nil) ctx := context.Background() - user := dbgen.User(t, srv.Database, database.User{}) - file := dbgen.File(t, srv.Database, database.File{CreatedBy: user.ID}) - _ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ + user := dbgen.User(t, db, database.User{}) + file := dbgen.File(t, db, database.File{CreatedBy: user.ID}) + _ = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ FileID: file.ID, InitiatorID: user.ID, Provisioner: database.ProvisionerTypeEcho, @@ -392,7 +397,7 @@ func TestAcquireJob(t *testing.T) { want, err := json.Marshal(&proto.AcquiredJob_TemplateImport_{ TemplateImport: &proto.AcquiredJob_TemplateImport{ Metadata: &sdkproto.Metadata{ - CoderUrl: srv.AccessURL.String(), + CoderUrl: (&url.URL{}).String(), }, }, }) @@ -401,12 +406,12 @@ func TestAcquireJob(t *testing.T) { }) t.Run("TemplateVersionImportWithUserVariable", func(t *testing.T) { t.Parallel() - srv := setup(t, false) + srv, db, _ := setup(t, false, nil) - user := dbgen.User(t, srv.Database, database.User{}) - version := dbgen.TemplateVersion(t, srv.Database, database.TemplateVersion{}) - file := dbgen.File(t, srv.Database, database.File{CreatedBy: user.ID}) - _ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ + user := dbgen.User(t, db, database.User{}) + version := dbgen.TemplateVersion(t, db, database.TemplateVersion{}) + file := dbgen.File(t, db, database.File{CreatedBy: user.ID}) + _ = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ FileID: file.ID, InitiatorID: user.ID, Provisioner: database.ProvisionerTypeEcho, @@ -435,7 +440,7 @@ func TestAcquireJob(t *testing.T) { {Name: "first", Sensitive: true, Value: "first_value"}, }, Metadata: &sdkproto.Metadata{ - CoderUrl: srv.AccessURL.String(), + CoderUrl: (&url.URL{}).String(), }, }, }) @@ -449,7 +454,7 @@ func TestUpdateJob(t *testing.T) { ctx := context.Background() t.Run("NotFound", func(t *testing.T) { t.Parallel() - srv := setup(t, false) + srv, _, _ := setup(t, false, nil) _, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: "hello", }) @@ -462,8 +467,8 @@ func TestUpdateJob(t *testing.T) { }) t.Run("NotRunning", func(t *testing.T) { t.Parallel() - srv := setup(t, false) - job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + srv, db, _ := setup(t, false, nil) + job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, StorageMethod: database.ProvisionerStorageMethodFile, @@ -478,15 +483,15 @@ func TestUpdateJob(t *testing.T) { // This test prevents runners from updating jobs they don't own! t.Run("NotOwner", func(t *testing.T) { t.Parallel() - srv := setup(t, false) - job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + srv, db, _ := setup(t, false, nil) + job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, StorageMethod: database.ProvisionerStorageMethodFile, Type: database.ProvisionerJobTypeTemplateVersionDryRun, }) require.NoError(t, err) - _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ UUID: uuid.New(), Valid: true, @@ -500,17 +505,17 @@ func TestUpdateJob(t *testing.T) { require.ErrorContains(t, err, "you don't own this job") }) - setupJob := func(t *testing.T, srv *provisionerdserver.Server) uuid.UUID { - job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + setupJob := func(t *testing.T, db database.Store, srvID uuid.UUID) uuid.UUID { + job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, Type: database.ProvisionerJobTypeTemplateVersionImport, StorageMethod: database.ProvisionerStorageMethodFile, }) require.NoError(t, err) - _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ - UUID: srv.ID, + UUID: srvID, Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, @@ -521,8 +526,9 @@ func TestUpdateJob(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - srv := setup(t, false) - job := setupJob(t, srv) + srvID := uuid.New() + srv, db, _ := setup(t, false, &overrides{id: &srvID}) + job := setupJob(t, db, srvID) _, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: job.String(), }) @@ -531,12 +537,13 @@ func TestUpdateJob(t *testing.T) { t.Run("Logs", func(t *testing.T) { t.Parallel() - srv := setup(t, false) - job := setupJob(t, srv) + srvID := uuid.New() + srv, db, ps := setup(t, false, &overrides{id: &srvID}) + job := setupJob(t, db, srvID) published := make(chan struct{}) - closeListener, err := srv.Pubsub.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job), func(_ context.Context, _ []byte) { + closeListener, err := ps.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job), func(_ context.Context, _ []byte) { close(published) }) require.NoError(t, err) @@ -556,10 +563,11 @@ func TestUpdateJob(t *testing.T) { }) t.Run("Readme", func(t *testing.T) { t.Parallel() - srv := setup(t, false) - job := setupJob(t, srv) + srvID := uuid.New() + srv, db, _ := setup(t, false, &overrides{id: &srvID}) + job := setupJob(t, db, srvID) versionID := uuid.New() - err := srv.Database.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ + err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ ID: versionID, JobID: job, }) @@ -570,7 +578,7 @@ func TestUpdateJob(t *testing.T) { }) require.NoError(t, err) - version, err := srv.Database.GetTemplateVersionByID(ctx, versionID) + version, err := db.GetTemplateVersionByID(ctx, versionID) require.NoError(t, err) require.Equal(t, "# hello world", version.Readme) }) @@ -582,10 +590,11 @@ func TestUpdateJob(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - srv := setup(t, false) - job := setupJob(t, srv) + srvID := uuid.New() + srv, db, _ := setup(t, false, &overrides{id: &srvID}) + job := setupJob(t, db, srvID) versionID := uuid.New() - err := srv.Database.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ + err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ ID: versionID, JobID: job, }) @@ -618,7 +627,7 @@ func TestUpdateJob(t *testing.T) { require.NoError(t, err) require.Len(t, response.VariableValues, 2) - templateVariables, err := srv.Database.GetTemplateVersionVariables(ctx, versionID) + templateVariables, err := db.GetTemplateVersionVariables(ctx, versionID) require.NoError(t, err) require.Len(t, templateVariables, 2) require.Equal(t, templateVariables[0].Value, firstTemplateVariable.DefaultValue) @@ -629,10 +638,11 @@ func TestUpdateJob(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - srv := setup(t, false) - job := setupJob(t, srv) + srvID := uuid.New() + srv, db, _ := setup(t, false, &overrides{id: &srvID}) + job := setupJob(t, db, srvID) versionID := uuid.New() - err := srv.Database.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ + err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ ID: versionID, JobID: job, }) @@ -661,7 +671,7 @@ func TestUpdateJob(t *testing.T) { // Even though there is an error returned, variables are stored in the database // to show the schema in the site UI. - templateVariables, err := srv.Database.GetTemplateVersionVariables(ctx, versionID) + templateVariables, err := db.GetTemplateVersionVariables(ctx, versionID) require.NoError(t, err) require.Len(t, templateVariables, 2) require.Equal(t, templateVariables[0].Value, firstTemplateVariable.DefaultValue) @@ -675,7 +685,7 @@ func TestFailJob(t *testing.T) { ctx := context.Background() t.Run("NotFound", func(t *testing.T) { t.Parallel() - srv := setup(t, false) + srv, _, _ := setup(t, false, nil) _, err := srv.FailJob(ctx, &proto.FailedJob{ JobId: "hello", }) @@ -689,15 +699,15 @@ func TestFailJob(t *testing.T) { // This test prevents runners from updating jobs they don't own! t.Run("NotOwner", func(t *testing.T) { t.Parallel() - srv := setup(t, false) - job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + srv, db, _ := setup(t, false, nil) + job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, StorageMethod: database.ProvisionerStorageMethodFile, Type: database.ProvisionerJobTypeTemplateVersionImport, }) require.NoError(t, err) - _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ UUID: uuid.New(), Valid: true, @@ -712,23 +722,24 @@ func TestFailJob(t *testing.T) { }) t.Run("AlreadyCompleted", func(t *testing.T) { t.Parallel() - srv := setup(t, false) - job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + srvID := uuid.New() + srv, db, _ := setup(t, false, &overrides{id: &srvID}) + job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, Type: database.ProvisionerJobTypeTemplateVersionImport, StorageMethod: database.ProvisionerStorageMethodFile, }) require.NoError(t, err) - _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ - UUID: srv.ID, + UUID: srvID, Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, }) require.NoError(t, err) - err = srv.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ ID: job.ID, CompletedAt: sql.NullTime{ Time: database.Now(), @@ -747,13 +758,14 @@ func TestFailJob(t *testing.T) { // // (*Server).FailJob audit log - get build {"error": "sql: no rows in result set"} ignoreLogErrors := true - srv := setup(t, ignoreLogErrors) - workspace, err := srv.Database.InsertWorkspace(ctx, database.InsertWorkspaceParams{ + srvID := uuid.New() + srv, db, ps := setup(t, ignoreLogErrors, &overrides{id: &srvID}) + workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{ ID: uuid.New(), }) require.NoError(t, err) buildID := uuid.New() - err = srv.Database.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ + err = db.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ ID: buildID, WorkspaceID: workspace.ID, Transition: database.WorkspaceTransitionStart, @@ -765,7 +777,7 @@ func TestFailJob(t *testing.T) { }) require.NoError(t, err) - job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), Input: input, Provisioner: database.ProvisionerTypeEcho, @@ -773,9 +785,9 @@ func TestFailJob(t *testing.T) { StorageMethod: database.ProvisionerStorageMethodFile, }) require.NoError(t, err) - _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ - UUID: srv.ID, + UUID: srvID, Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, @@ -783,13 +795,13 @@ func TestFailJob(t *testing.T) { require.NoError(t, err) publishedWorkspace := make(chan struct{}) - closeWorkspaceSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) { + closeWorkspaceSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) { close(publishedWorkspace) }) require.NoError(t, err) defer closeWorkspaceSubscribe() publishedLogs := make(chan struct{}) - closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) { + closeLogsSubscribe, err := ps.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) { close(publishedLogs) }) require.NoError(t, err) @@ -806,7 +818,7 @@ func TestFailJob(t *testing.T) { require.NoError(t, err) <-publishedWorkspace <-publishedLogs - build, err := srv.Database.GetWorkspaceBuildByID(ctx, buildID) + build, err := db.GetWorkspaceBuildByID(ctx, buildID) require.NoError(t, err) require.Equal(t, "some state", string(build.ProvisionerState)) }) @@ -817,7 +829,7 @@ func TestCompleteJob(t *testing.T) { ctx := context.Background() t.Run("NotFound", func(t *testing.T) { t.Parallel() - srv := setup(t, false) + srv, _, _ := setup(t, false, nil) _, err := srv.CompleteJob(ctx, &proto.CompletedJob{ JobId: "hello", }) @@ -831,15 +843,15 @@ func TestCompleteJob(t *testing.T) { // This test prevents runners from updating jobs they don't own! t.Run("NotOwner", func(t *testing.T) { t.Parallel() - srv := setup(t, false) - job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + srv, db, _ := setup(t, false, nil) + job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, StorageMethod: database.ProvisionerStorageMethodFile, Type: database.ProvisionerJobTypeWorkspaceBuild, }) require.NoError(t, err) - _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ UUID: uuid.New(), Valid: true, @@ -852,17 +864,19 @@ func TestCompleteJob(t *testing.T) { }) require.ErrorContains(t, err, "you don't own this job") }) - t.Run("TemplateImport", func(t *testing.T) { + + t.Run("TemplateImport_MissingGitAuth", func(t *testing.T) { t.Parallel() - srv := setup(t, false) + srvID := uuid.New() + srv, db, _ := setup(t, false, &overrides{id: &srvID}) jobID := uuid.New() versionID := uuid.New() - err := srv.Database.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ + err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ ID: versionID, JobID: jobID, }) require.NoError(t, err) - job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: jobID, Provisioner: database.ProvisionerTypeEcho, Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`), @@ -870,9 +884,9 @@ func TestCompleteJob(t *testing.T) { Type: database.ProvisionerJobTypeWorkspaceBuild, }) require.NoError(t, err) - _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ - UUID: srv.ID, + UUID: srvID, Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, @@ -895,14 +909,61 @@ func TestCompleteJob(t *testing.T) { require.NoError(t, err) } completeJob() - job, err = srv.Database.GetProvisionerJobByID(ctx, job.ID) + job, err = db.GetProvisionerJobByID(ctx, job.ID) require.NoError(t, err) require.Contains(t, job.Error.String, `git auth provider "github" is not configured`) - srv.GitAuthConfigs = []*gitauth.Config{{ - ID: "github", - }} + }) + + t.Run("TemplateImport_WithGitAuth", func(t *testing.T) { + t.Parallel() + srvID := uuid.New() + srv, db, _ := setup(t, false, &overrides{ + id: &srvID, + gitAuthConfigs: []*gitauth.Config{{ + ID: "github", + }}, + }) + jobID := uuid.New() + versionID := uuid.New() + err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ + ID: versionID, + JobID: jobID, + }) + require.NoError(t, err) + job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: jobID, + Provisioner: database.ProvisionerTypeEcho, + Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`), + StorageMethod: database.ProvisionerStorageMethodFile, + Type: database.ProvisionerJobTypeWorkspaceBuild, + }) + require.NoError(t, err) + _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + WorkerID: uuid.NullUUID{ + UUID: srvID, + Valid: true, + }, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + }) + require.NoError(t, err) + completeJob := func() { + _, err = srv.CompleteJob(ctx, &proto.CompletedJob{ + JobId: job.ID.String(), + Type: &proto.CompletedJob_TemplateImport_{ + TemplateImport: &proto.CompletedJob_TemplateImport{ + StartResources: []*sdkproto.Resource{{ + Name: "hello", + Type: "aws_instance", + }}, + StopResources: []*sdkproto.Resource{}, + GitAuthProviders: []string{"github"}, + }, + }, + }) + require.NoError(t, err) + } completeJob() - job, err = srv.Database.GetProvisionerJobByID(ctx, job.ID) + job, err = db.GetProvisionerJobByID(ctx, job.ID) require.NoError(t, err) require.False(t, job.Error.Valid) }) @@ -1022,7 +1083,9 @@ func TestCompleteJob(t *testing.T) { t.Run(c.name, func(t *testing.T) { t.Parallel() - srv := setup(t, false) + srvID := uuid.New() + tss := &atomic.Pointer[schedule.TemplateScheduleStore]{} + srv, db, ps := setup(t, false, &overrides{id: &srvID, templateScheduleStore: tss}) var store schedule.TemplateScheduleStore = schedule.MockTemplateScheduleStore{ GetFn: func(_ context.Context, _ database.Store, _ uuid.UUID) (schedule.TemplateScheduleOptions, error) { @@ -1035,14 +1098,14 @@ func TestCompleteJob(t *testing.T) { }, nil }, } - srv.TemplateScheduleStore.Store(&store) + tss.Store(&store) - user := dbgen.User(t, srv.Database, database.User{}) - template := dbgen.Template(t, srv.Database, database.Template{ + user := dbgen.User(t, db, database.User{}) + template := dbgen.Template(t, db, database.Template{ Name: "template", Provisioner: database.ProvisionerTypeEcho, }) - err := srv.Database.UpdateTemplateScheduleByID(ctx, database.UpdateTemplateScheduleByIDParams{ + err := db.UpdateTemplateScheduleByID(ctx, database.UpdateTemplateScheduleByIDParams{ ID: template.ID, UpdatedAt: database.Now(), AllowUserAutostart: c.templateAllowAutostop, @@ -1050,7 +1113,7 @@ func TestCompleteJob(t *testing.T) { MaxTTL: int64(c.templateMaxTTL), }) require.NoError(t, err) - file := dbgen.File(t, srv.Database, database.File{CreatedBy: user.ID}) + file := dbgen.File(t, db, database.File{CreatedBy: user.ID}) workspaceTTL := sql.NullInt64{} if c.workspaceTTL != 0 { workspaceTTL = sql.NullInt64{ @@ -1058,33 +1121,33 @@ func TestCompleteJob(t *testing.T) { Valid: true, } } - workspace := dbgen.Workspace(t, srv.Database, database.Workspace{ + workspace := dbgen.Workspace(t, db, database.Workspace{ TemplateID: template.ID, Ttl: workspaceTTL, }) - version := dbgen.TemplateVersion(t, srv.Database, database.TemplateVersion{ + version := dbgen.TemplateVersion(t, db, database.TemplateVersion{ TemplateID: uuid.NullUUID{ UUID: template.ID, Valid: true, }, JobID: uuid.New(), }) - build := dbgen.WorkspaceBuild(t, srv.Database, database.WorkspaceBuild{ + build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ WorkspaceID: workspace.ID, TemplateVersionID: version.ID, Transition: c.transition, Reason: database.BuildReasonInitiator, }) - job := dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ + job := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ FileID: file.ID, Type: database.ProvisionerJobTypeWorkspaceBuild, Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{ WorkspaceBuildID: build.ID, })), }) - _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ - UUID: srv.ID, + UUID: srvID, Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, @@ -1092,13 +1155,13 @@ func TestCompleteJob(t *testing.T) { require.NoError(t, err) publishedWorkspace := make(chan struct{}) - closeWorkspaceSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), func(_ context.Context, _ []byte) { + closeWorkspaceSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), func(_ context.Context, _ []byte) { close(publishedWorkspace) }) require.NoError(t, err) defer closeWorkspaceSubscribe() publishedLogs := make(chan struct{}) - closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) { + closeLogsSubscribe, err := ps.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) { close(publishedLogs) }) require.NoError(t, err) @@ -1121,11 +1184,11 @@ func TestCompleteJob(t *testing.T) { <-publishedWorkspace <-publishedLogs - workspace, err = srv.Database.GetWorkspaceByID(ctx, workspace.ID) + workspace, err = db.GetWorkspaceByID(ctx, workspace.ID) require.NoError(t, err) require.Equal(t, c.transition == database.WorkspaceTransitionDelete, workspace.Deleted) - workspaceBuild, err := srv.Database.GetWorkspaceBuildByID(ctx, build.ID) + workspaceBuild, err := db.GetWorkspaceBuildByID(ctx, build.ID) require.NoError(t, err) if c.expectedTTL == 0 { @@ -1229,14 +1292,20 @@ func TestCompleteJob(t *testing.T) { t.Run(c.name, func(t *testing.T) { t.Parallel() - srv := setup(t, false) - + srvID := uuid.New() // Simulate the given time starting from now. require.False(t, c.now.IsZero()) start := time.Now() - srv.TimeNowFn = func() time.Time { - return c.now.Add(time.Since(start)) - } + tss := &atomic.Pointer[schedule.TemplateScheduleStore]{} + uqhss := &atomic.Pointer[schedule.UserQuietHoursScheduleStore]{} + srv, db, ps := setup(t, false, &overrides{ + timeNowFn: func() time.Time { + return c.now.Add(time.Since(start)) + }, + templateScheduleStore: tss, + userQuietHoursScheduleStore: uqhss, + id: &srvID, + }) var templateScheduleStore schedule.TemplateScheduleStore = schedule.MockTemplateScheduleStore{ GetFn: func(_ context.Context, _ database.Store, _ uuid.UUID) (schedule.TemplateScheduleOptions, error) { @@ -1249,7 +1318,7 @@ func TestCompleteJob(t *testing.T) { }, nil }, } - srv.TemplateScheduleStore.Store(&templateScheduleStore) + tss.Store(&templateScheduleStore) var userQuietHoursScheduleStore schedule.UserQuietHoursScheduleStore = schedule.MockUserQuietHoursScheduleStore{ GetFn: func(_ context.Context, _ database.Store, _ uuid.UUID) (schedule.UserQuietHoursScheduleOptions, error) { @@ -1270,16 +1339,16 @@ func TestCompleteJob(t *testing.T) { }, nil }, } - srv.UserQuietHoursScheduleStore.Store(&userQuietHoursScheduleStore) + uqhss.Store(&userQuietHoursScheduleStore) - user := dbgen.User(t, srv.Database, database.User{ + user := dbgen.User(t, db, database.User{ QuietHoursSchedule: c.userQuietHoursSchedule, }) - template := dbgen.Template(t, srv.Database, database.Template{ + template := dbgen.Template(t, db, database.Template{ Name: "template", Provisioner: database.ProvisionerTypeEcho, }) - err := srv.Database.UpdateTemplateScheduleByID(ctx, database.UpdateTemplateScheduleByIDParams{ + err := db.UpdateTemplateScheduleByID(ctx, database.UpdateTemplateScheduleByIDParams{ ID: template.ID, UpdatedAt: database.Now(), AllowUserAutostart: false, @@ -1289,9 +1358,9 @@ func TestCompleteJob(t *testing.T) { RestartRequirementWeeks: c.templateRestartRequirement.Weeks, }) require.NoError(t, err) - template, err = srv.Database.GetTemplateByID(ctx, template.ID) + template, err = db.GetTemplateByID(ctx, template.ID) require.NoError(t, err) - file := dbgen.File(t, srv.Database, database.File{CreatedBy: user.ID}) + file := dbgen.File(t, db, database.File{CreatedBy: user.ID}) workspaceTTL := sql.NullInt64{} if c.workspaceTTL != 0 { workspaceTTL = sql.NullInt64{ @@ -1299,34 +1368,34 @@ func TestCompleteJob(t *testing.T) { Valid: true, } } - workspace := dbgen.Workspace(t, srv.Database, database.Workspace{ + workspace := dbgen.Workspace(t, db, database.Workspace{ TemplateID: template.ID, Ttl: workspaceTTL, OwnerID: user.ID, }) - version := dbgen.TemplateVersion(t, srv.Database, database.TemplateVersion{ + version := dbgen.TemplateVersion(t, db, database.TemplateVersion{ TemplateID: uuid.NullUUID{ UUID: template.ID, Valid: true, }, JobID: uuid.New(), }) - build := dbgen.WorkspaceBuild(t, srv.Database, database.WorkspaceBuild{ + build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ WorkspaceID: workspace.ID, TemplateVersionID: version.ID, Transition: c.transition, Reason: database.BuildReasonInitiator, }) - job := dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ + job := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ FileID: file.ID, Type: database.ProvisionerJobTypeWorkspaceBuild, Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{ WorkspaceBuildID: build.ID, })), }) - _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ - UUID: srv.ID, + UUID: srvID, Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, @@ -1334,13 +1403,13 @@ func TestCompleteJob(t *testing.T) { require.NoError(t, err) publishedWorkspace := make(chan struct{}) - closeWorkspaceSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), func(_ context.Context, _ []byte) { + closeWorkspaceSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), func(_ context.Context, _ []byte) { close(publishedWorkspace) }) require.NoError(t, err) defer closeWorkspaceSubscribe() publishedLogs := make(chan struct{}) - closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) { + closeLogsSubscribe, err := ps.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) { close(publishedLogs) }) require.NoError(t, err) @@ -1363,11 +1432,11 @@ func TestCompleteJob(t *testing.T) { <-publishedWorkspace <-publishedLogs - workspace, err = srv.Database.GetWorkspaceByID(ctx, workspace.ID) + workspace, err = db.GetWorkspaceByID(ctx, workspace.ID) require.NoError(t, err) require.Equal(t, c.transition == database.WorkspaceTransitionDelete, workspace.Deleted) - workspaceBuild, err := srv.Database.GetWorkspaceBuildByID(ctx, build.ID) + workspaceBuild, err := db.GetWorkspaceBuildByID(ctx, build.ID) require.NoError(t, err) // If the max deadline is set, the deadline should also be set. @@ -1392,17 +1461,18 @@ func TestCompleteJob(t *testing.T) { }) t.Run("TemplateDryRun", func(t *testing.T) { t.Parallel() - srv := setup(t, false) - job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + srvID := uuid.New() + srv, db, _ := setup(t, false, &overrides{id: &srvID}) + job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, Type: database.ProvisionerJobTypeTemplateVersionDryRun, StorageMethod: database.ProvisionerStorageMethodFile, }) require.NoError(t, err) - _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ - UUID: srv.ID, + UUID: srvID, Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, @@ -1519,30 +1589,81 @@ func TestInsertWorkspaceResource(t *testing.T) { }) } -func setup(t *testing.T, ignoreLogErrors bool) *provisionerdserver.Server { +type overrides struct { + deploymentValues *codersdk.DeploymentValues + gitAuthConfigs []*gitauth.Config + id *uuid.UUID + templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] + userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore] + timeNowFn func() time.Time +} + +func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub) { t.Helper() db := dbfake.New() ps := pubsub.NewInMemory() - - return &provisionerdserver.Server{ - ID: uuid.New(), - Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}), - OIDCConfig: &oauth2.Config{}, - AccessURL: &url.URL{}, - Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho}, - Database: db, - Pubsub: ps, - Telemetry: telemetry.NewNoop(), - Auditor: mockAuditor(), - TemplateScheduleStore: testTemplateScheduleStore(), - UserQuietHoursScheduleStore: testUserQuietHoursScheduleStore(), - Tracer: trace.NewNoopTracerProvider().Tracer("noop"), - DeploymentValues: &codersdk.DeploymentValues{}, - - // Negative values cause the debounce to never kick in. Tests that want - // to test debounce can override this value. - AcquireJobDebounce: -time.Minute, + deploymentValues := &codersdk.DeploymentValues{} + var gitAuthConfigs []*gitauth.Config + srvID := uuid.New() + tss := testTemplateScheduleStore() + uqhss := testUserQuietHoursScheduleStore() + var timeNowFn func() time.Time + if ov != nil { + if ov.deploymentValues != nil { + deploymentValues = ov.deploymentValues + } + if ov.gitAuthConfigs != nil { + gitAuthConfigs = ov.gitAuthConfigs + } + if ov.id != nil { + srvID = *ov.id + } + if ov.templateScheduleStore != nil { + ttss := tss.Load() + // keep the initial test value if the override hasn't set the atomic pointer. + tss = ov.templateScheduleStore + if tss.Load() == nil { + swapped := tss.CompareAndSwap(nil, ttss) + require.True(t, swapped) + } + } + if ov.userQuietHoursScheduleStore != nil { + tuqhss := uqhss.Load() + // keep the initial test value if the override hasn't set the atomic pointer. + uqhss = ov.userQuietHoursScheduleStore + if uqhss.Load() == nil { + swapped := uqhss.CompareAndSwap(nil, tuqhss) + require.True(t, swapped) + } + } + if ov.timeNowFn != nil { + timeNowFn = ov.timeNowFn + } } + + return provisionerdserver.NewServer( + &url.URL{}, + srvID, + slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}), + []database.ProvisionerType{database.ProvisionerTypeEcho}, + nil, + db, + ps, + telemetry.NewNoop(), + trace.NewNoopTracerProvider().Tracer("noop"), + &atomic.Pointer[proto.QuotaCommitter]{}, + mockAuditor(), + tss, + uqhss, + deploymentValues, + // Negative values cause the debounce to never kick in. + -time.Minute, + provisionerdserver.Options{ + GitAuthConfigs: gitAuthConfigs, + TimeNowFn: timeNowFn, + OIDCConfig: &oauth2.Config{}, + }, + ), db, ps } func must[T any](value T, err error) T { diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index b82c2a6c750f1..31cff09b19dfe 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -11,6 +11,7 @@ import ( "net" "net/http" "strings" + "time" "github.com/google/uuid" "github.com/hashicorp/yamux" @@ -243,23 +244,28 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) return } mux := drpcmux.New() - err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{ - AccessURL: api.AccessURL, - GitAuthConfigs: api.GitAuthConfigs, - OIDCConfig: api.OIDCConfig, - ID: daemon.ID, - Database: api.Database, - Pubsub: api.Pubsub, - Provisioners: daemon.Provisioners, - Telemetry: api.Telemetry, - Auditor: &api.AGPL.Auditor, - TemplateScheduleStore: api.AGPL.TemplateScheduleStore, - UserQuietHoursScheduleStore: api.AGPL.UserQuietHoursScheduleStore, - Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)), - Tags: rawTags, - Tracer: trace.NewNoopTracerProvider().Tracer("noop"), - DeploymentValues: api.DeploymentValues, - }) + debounce := time.Second + err = proto.DRPCRegisterProvisionerDaemon(mux, provisionerdserver.NewServer( + api.AccessURL, + daemon.ID, + api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)), + daemon.Provisioners, + rawTags, + api.Database, + api.Pubsub, + api.Telemetry, + trace.NewNoopTracerProvider().Tracer("noop"), + &api.AGPL.QuotaCommitter, + &api.AGPL.Auditor, + api.AGPL.TemplateScheduleStore, + api.AGPL.UserQuietHoursScheduleStore, + api.DeploymentValues, + debounce, + provisionerdserver.Options{ + GitAuthConfigs: api.GitAuthConfigs, + OIDCConfig: api.OIDCConfig, + }, + )) if err != nil { _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("drpc register provisioner daemon: %s", err)) return diff --git a/enterprise/coderd/provisionerdaemons_test.go b/enterprise/coderd/provisionerdaemons_test.go index e190a3df90e20..d41bf42385e76 100644 --- a/enterprise/coderd/provisionerdaemons_test.go +++ b/enterprise/coderd/provisionerdaemons_test.go @@ -9,13 +9,20 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/provisionerdserver" "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" "github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coder/coder/v2/provisioner/echo" + "github.com/coder/coder/v2/provisionerd" + provisionerdproto "github.com/coder/coder/v2/provisionerd/proto" + "github.com/coder/coder/v2/provisionersdk" "github.com/coder/coder/v2/provisionersdk/proto" "github.com/coder/coder/v2/testutil" ) @@ -212,6 +219,107 @@ func TestProvisionerDaemonServe(t *testing.T) { require.Len(t, daemons, 1) }) + t.Run("PSK_daily_cost", func(t *testing.T) { + t.Parallel() + client, user := coderdenttest.New(t, &coderdenttest.Options{ + UserWorkspaceQuota: 10, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + codersdk.FeatureTemplateRBAC: 1, + }, + }, + ProvisionerDaemonPSK: "provisionersftw", + }) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + terraformClient, terraformServer := provisionersdk.MemTransportPipe() + go func() { + <-ctx.Done() + _ = terraformClient.Close() + _ = terraformServer.Close() + }() + + tempDir := t.TempDir() + errCh := make(chan error) + go func() { + err := echo.Serve(ctx, &provisionersdk.ServeOptions{ + Listener: terraformServer, + Logger: logger.Named("echo"), + WorkDirectory: tempDir, + }) + errCh <- err + }() + + provisioners := provisionerd.Provisioners{ + string(database.ProvisionerTypeEcho): proto.NewDRPCProvisionerClient(terraformClient), + } + another := codersdk.New(client.URL) + pd := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) { + return another.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{ + Organization: user.OrganizationID, + Provisioners: []codersdk.ProvisionerType{ + codersdk.ProvisionerTypeEcho, + }, + Tags: map[string]string{ + provisionerdserver.TagScope: provisionerdserver.ScopeOrganization, + }, + PreSharedKey: "provisionersftw", + }) + }, &provisionerd.Options{ + Logger: logger.Named("provisionerd"), + Provisioners: provisioners, + }) + defer pd.Close() + + // Patch the 'Everyone' group to give the user quota to build their workspace. + _, err := client.PatchGroup(ctx, user.OrganizationID, codersdk.PatchGroupRequest{ + QuotaAllowance: ptr.Ref(1), + }) + require.NoError(t, err) + + authToken := uuid.NewString() + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionApply: []*proto.Response{{ + Type: &proto.Response_Apply{ + Apply: &proto.ApplyComplete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + DailyCost: 1, + Agents: []*proto.Agent{{ + Id: uuid.NewString(), + Name: "example", + Auth: &proto.Agent_Token{ + Token: authToken, + }, + }}, + }}, + }, + }, + }}, + }) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + build := coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status) + + err = pd.Shutdown(ctx) + require.NoError(t, err) + err = terraformServer.Close() + require.NoError(t, err) + select { + case <-ctx.Done(): + t.Error("timeout waiting for server to shut down") + case err := <-errCh: + require.NoError(t, err) + } + }) + t.Run("BadPSK", func(t *testing.T) { t.Parallel() client, user := coderdenttest.New(t, &coderdenttest.Options{ From 3c0d020a5cb7ce46536f0c7fc33a13d4c8439d96 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 29 Aug 2023 12:45:25 +0000 Subject: [PATCH 2/4] Add logging for debounce and job acquire Signed-off-by: Spike Curtis --- coderd/provisionerdserver/provisionerdserver.go | 1 + provisionerd/provisionerd.go | 2 ++ 2 files changed, 3 insertions(+) diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index f07513d83ad38..e88e47bb55999 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -154,6 +154,7 @@ func (s *server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Acquire // jobs are added at once, they will start after at most this duration. lastAcquireMutex.RLock() if !lastAcquire.IsZero() && time.Since(lastAcquire) < s.AcquireJobDebounce { + s.Logger.Debug(ctx, "debounce acquire job", slog.F("debounce", s.AcquireJobDebounce), slog.F("last_acquire", lastAcquire)) lastAcquireMutex.RUnlock() return &proto.AcquiredJob{}, nil } diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index a341bd5a3df85..81346fa25873b 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -308,6 +308,7 @@ func (p *Server) acquireJob(ctx context.Context) { lastAcquireMutex.RLock() if !lastAcquire.IsZero() && time.Since(lastAcquire) < p.opts.JobPollDebounce { lastAcquireMutex.RUnlock() + p.opts.Logger.Debug(ctx, "debounce acquire job") return } lastAcquireMutex.RUnlock() @@ -319,6 +320,7 @@ func (p *Server) acquireJob(ctx context.Context) { } job, err := client.AcquireJob(ctx, &proto.Empty{}) + p.opts.Logger.Debug(ctx, "called AcquireJob on client", slog.F("job_id", job.GetJobId()), slog.Error(err)) if err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, yamux.ErrSessionShutdown) || From a91d28f8281b4a95614d8150138e1e932560b399 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 30 Aug 2023 06:19:16 +0000 Subject: [PATCH 3/4] Return error instead of panic Signed-off-by: Spike Curtis --- coderd/coderd.go | 9 ++++++--- coderd/provisionerdserver/provisionerdserver.go | 14 +++++++------- .../provisionerdserver/provisionerdserver_test.go | 9 ++++++--- enterprise/coderd/provisionerdaemons.go | 9 +++++++-- 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index 27d566753909a..53c7b165baf9a 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1098,8 +1098,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti } mux := drpcmux.New() - - err = proto.DRPCRegisterProvisionerDaemon(mux, provisionerdserver.NewServer( + srv, err := provisionerdserver.NewServer( api.AccessURL, daemon.ID, api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)), @@ -1119,7 +1118,11 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti OIDCConfig: api.OIDCConfig, GitAuthConfigs: api.GitAuthConfigs, }, - )) + ) + if err != nil { + return nil, err + } + err = proto.DRPCRegisterProvisionerDaemon(mux, srv) if err != nil { return nil, err } diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index e88e47bb55999..7aa7cbf88d83a 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -95,22 +95,22 @@ func NewServer( deploymentValues *codersdk.DeploymentValues, acquireJobDebounce time.Duration, options Options, -) proto.DRPCProvisionerDaemonServer { +) (proto.DRPCProvisionerDaemonServer, error) { // Panic early if pointers are nil if quotaCommitter == nil { - panic("quotaCommitter is nil") + return nil, xerrors.New("quotaCommitter is nil") } if auditor == nil { - panic("auditor is nil") + return nil, xerrors.New("auditor is nil") } if templateScheduleStore == nil { - panic("templateScheduleStore is nil") + return nil, xerrors.New("templateScheduleStore is nil") } if userQuietHoursScheduleStore == nil { - panic("userQuietHoursScheduleStore is nil") + return nil, xerrors.New("userQuietHoursScheduleStore is nil") } if deploymentValues == nil { - panic("deploymentValues is nil") + return nil, xerrors.New("deploymentValues is nil") } return &server{ AccessURL: accessURL, @@ -131,7 +131,7 @@ func NewServer( AcquireJobDebounce: acquireJobDebounce, OIDCConfig: options.OIDCConfig, TimeNowFn: options.TimeNowFn, - } + }, nil } // timeNow should be used when trying to get the current time for math diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index eba7ec2cbe267..7fcc5d19d20c4 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -61,7 +61,7 @@ func TestAcquireJob(t *testing.T) { t.Parallel() db := dbfake.New() ps := pubsub.NewInMemory() - srv := provisionerdserver.NewServer( + srv, err := provisionerdserver.NewServer( &url.URL{}, uuid.New(), slogtest.Make(t, nil), @@ -79,6 +79,7 @@ func TestAcquireJob(t *testing.T) { time.Hour, provisionerdserver.Options{}, ) + require.NoError(t, err) job, err := srv.AcquireJob(context.Background(), nil) require.NoError(t, err) require.Equal(t, &proto.AcquiredJob{}, job) @@ -1641,7 +1642,7 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi } } - return provisionerdserver.NewServer( + srv, err := provisionerdserver.NewServer( &url.URL{}, srvID, slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}), @@ -1663,7 +1664,9 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi TimeNowFn: timeNowFn, OIDCConfig: &oauth2.Config{}, }, - ), db, ps + ) + require.NoError(t, err) + return srv, db, ps } func must[T any](value T, err error) T { diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index 31cff09b19dfe..6315665c405c6 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -245,7 +245,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) } mux := drpcmux.New() debounce := time.Second - err = proto.DRPCRegisterProvisionerDaemon(mux, provisionerdserver.NewServer( + srv, err := provisionerdserver.NewServer( api.AccessURL, daemon.ID, api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)), @@ -265,7 +265,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) GitAuthConfigs: api.GitAuthConfigs, OIDCConfig: api.OIDCConfig, }, - )) + ) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("create provisioner daemon server: %s", err)) + return + } + err = proto.DRPCRegisterProvisionerDaemon(mux, srv) if err != nil { _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("drpc register provisioner daemon: %s", err)) return From 8e32e5ca6be399f715d4a7aa4e5723011e2d6b83 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 30 Aug 2023 10:04:16 +0000 Subject: [PATCH 4/4] remove debounce on external provisioners to fix test flakes Signed-off-by: Spike Curtis --- enterprise/coderd/provisionerdaemons.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index 6315665c405c6..1129eb4004185 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -244,7 +244,6 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) return } mux := drpcmux.New() - debounce := time.Second srv, err := provisionerdserver.NewServer( api.AccessURL, daemon.ID, @@ -260,7 +259,8 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) api.AGPL.TemplateScheduleStore, api.AGPL.UserQuietHoursScheduleStore, api.DeploymentValues, - debounce, + // TODO(spikecurtis) - fix debounce to not cause flaky tests. + time.Duration(0), provisionerdserver.Options{ GitAuthConfigs: api.GitAuthConfigs, OIDCConfig: api.OIDCConfig,