From bff51bcf79be540d30c7d323b10dc54e8fe8f5a8 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 7 Nov 2022 21:56:33 +0000 Subject: [PATCH 1/5] chore: Separate the provisionerd server into its own package This code should be thoroughly tested now that we understand the abstraction. I separated it to make our lives a bit easier for external provisioner daemons as well! --- .vscode/settings.json | 1 + coderd/provisionerdaemons.go | 905 +---------------- .../provisionerdserver/provisionerdserver.go | 932 ++++++++++++++++++ coderd/templateversions.go | 5 +- coderd/workspacebuilds.go | 3 +- coderd/workspaceresourceauth.go | 3 +- coderd/workspaces.go | 3 +- codersdk/workspaces.go | 7 + 8 files changed, 951 insertions(+), 908 deletions(-) create mode 100644 coderd/provisionerdserver/provisionerdserver.go diff --git a/.vscode/settings.json b/.vscode/settings.json index 2f1443066768d..ba58f6f4ee1bf 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -88,6 +88,7 @@ "promptui", "protobuf", "provisionerd", + "provisionerdserver", "provisionersdk", "ptty", "ptys", diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go index 0fe8c096c868b..1336e84a6fb0c 100644 --- a/coderd/provisionerdaemons.go +++ b/coderd/provisionerdaemons.go @@ -3,20 +3,14 @@ package coderd import ( "context" "database/sql" - "encoding/json" "errors" "fmt" "io" "net/http" - "net/url" - "reflect" - "time" "github.com/google/uuid" "github.com/moby/moby/pkg/namesgenerator" - "github.com/tabbed/pqtype" "golang.org/x/xerrors" - protobuf "google.golang.org/protobuf/proto" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" @@ -24,14 +18,11 @@ import ( "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/httpapi" - "github.com/coder/coder/coderd/parameter" + "github.com/coder/coder/coderd/provisionerdserver" "github.com/coder/coder/coderd/rbac" - "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/codersdk" - "github.com/coder/coder/provisioner" 
"github.com/coder/coder/provisionerd/proto" "github.com/coder/coder/provisionersdk" - sdkproto "github.com/coder/coder/provisionersdk/proto" ) func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) { @@ -85,7 +76,7 @@ func (api *API) ListenProvisionerDaemon(ctx context.Context) (client proto.DRPCP } mux := drpcmux.New() - err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdServer{ + err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{ AccessURL: api.AccessURL, ID: daemon.ID, Database: api.Database, @@ -117,895 +108,3 @@ func (api *API) ListenProvisionerDaemon(ctx context.Context) (client proto.DRPCP return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(clientSession)), nil } - -// The input for a "workspace_provision" job. -type workspaceProvisionJob struct { - WorkspaceBuildID uuid.UUID `json:"workspace_build_id"` - DryRun bool `json:"dry_run"` -} - -// The input for a "template_version_dry_run" job. -type templateVersionDryRunJob struct { - TemplateVersionID uuid.UUID `json:"template_version_id"` - WorkspaceName string `json:"workspace_name"` - ParameterValues []database.ParameterValue `json:"parameter_values"` -} - -// Implementation of the provisioner daemon protobuf server. -type provisionerdServer struct { - AccessURL *url.URL - ID uuid.UUID - Logger slog.Logger - Provisioners []database.ProvisionerType - Database database.Store - Pubsub database.Pubsub - Telemetry telemetry.Reporter -} - -// AcquireJob queries the database to lock a job. -func (server *provisionerdServer) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { - // This marks the job as locked in the database. 
- job, err := server.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ - StartedAt: sql.NullTime{ - Time: database.Now(), - Valid: true, - }, - WorkerID: uuid.NullUUID{ - UUID: server.ID, - Valid: true, - }, - Types: server.Provisioners, - }) - if errors.Is(err, sql.ErrNoRows) { - // The provisioner daemon assumes no jobs are available if - // an empty struct is returned. - return &proto.AcquiredJob{}, nil - } - if err != nil { - return nil, xerrors.Errorf("acquire job: %w", err) - } - server.Logger.Debug(ctx, "locked job from database", slog.F("id", job.ID)) - - // Marks the acquired job as failed with the error message provided. - failJob := func(errorMessage string) error { - err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ - ID: job.ID, - CompletedAt: sql.NullTime{ - Time: database.Now(), - Valid: true, - }, - Error: sql.NullString{ - String: errorMessage, - Valid: true, - }, - }) - if err != nil { - return xerrors.Errorf("update provisioner job: %w", err) - } - return xerrors.Errorf("request job was invalidated: %s", errorMessage) - } - - user, err := server.Database.GetUserByID(ctx, job.InitiatorID) - if err != nil { - return nil, failJob(fmt.Sprintf("get user: %s", err)) - } - - protoJob := &proto.AcquiredJob{ - JobId: job.ID.String(), - CreatedAt: job.CreatedAt.UnixMilli(), - Provisioner: string(job.Provisioner), - UserName: user.Username, - } - switch job.Type { - case database.ProvisionerJobTypeWorkspaceBuild: - var input workspaceProvisionJob - err = json.Unmarshal(job.Input, &input) - if err != nil { - return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err)) - } - workspaceBuild, err := server.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID) - if err != nil { - return nil, failJob(fmt.Sprintf("get workspace build: %s", err)) - } - workspace, err := server.Database.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID) - if err != nil { 
- return nil, failJob(fmt.Sprintf("get workspace: %s", err)) - } - templateVersion, err := server.Database.GetTemplateVersionByID(ctx, workspaceBuild.TemplateVersionID) - if err != nil { - return nil, failJob(fmt.Sprintf("get template version: %s", err)) - } - template, err := server.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID) - if err != nil { - return nil, failJob(fmt.Sprintf("get template: %s", err)) - } - owner, err := server.Database.GetUserByID(ctx, workspace.OwnerID) - if err != nil { - return nil, failJob(fmt.Sprintf("get owner: %s", err)) - } - err = server.Pubsub.Publish(watchWorkspaceChannel(workspace.ID), []byte{}) - if err != nil { - return nil, failJob(fmt.Sprintf("publish workspace update: %s", err)) - } - - // Compute parameters for the workspace to consume. - parameters, err := parameter.Compute(ctx, server.Database, parameter.ComputeScope{ - TemplateImportJobID: templateVersion.JobID, - TemplateID: uuid.NullUUID{ - UUID: template.ID, - Valid: true, - }, - WorkspaceID: uuid.NullUUID{ - UUID: workspace.ID, - Valid: true, - }, - }, nil) - if err != nil { - return nil, failJob(fmt.Sprintf("compute parameters: %s", err)) - } - - // Convert types to their corresponding protobuf types. 
- protoParameters, err := convertComputedParameterValues(parameters) - if err != nil { - return nil, failJob(fmt.Sprintf("convert computed parameters to protobuf: %s", err)) - } - transition, err := convertWorkspaceTransition(workspaceBuild.Transition) - if err != nil { - return nil, failJob(fmt.Sprintf("convert workspace transition: %s", err)) - } - - protoJob.Type = &proto.AcquiredJob_WorkspaceBuild_{ - WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ - WorkspaceBuildId: workspaceBuild.ID.String(), - WorkspaceName: workspace.Name, - State: workspaceBuild.ProvisionerState, - ParameterValues: protoParameters, - Metadata: &sdkproto.Provision_Metadata{ - CoderUrl: server.AccessURL.String(), - WorkspaceTransition: transition, - WorkspaceName: workspace.Name, - WorkspaceOwner: owner.Username, - WorkspaceOwnerEmail: owner.Email, - WorkspaceId: workspace.ID.String(), - WorkspaceOwnerId: owner.ID.String(), - }, - }, - } - case database.ProvisionerJobTypeTemplateVersionDryRun: - var input templateVersionDryRunJob - err = json.Unmarshal(job.Input, &input) - if err != nil { - return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err)) - } - - templateVersion, err := server.Database.GetTemplateVersionByID(ctx, input.TemplateVersionID) - if err != nil { - return nil, failJob(fmt.Sprintf("get template version: %s", err)) - } - - // Compute parameters for the dry-run to consume. - parameters, err := parameter.Compute(ctx, server.Database, parameter.ComputeScope{ - TemplateImportJobID: templateVersion.JobID, - TemplateID: templateVersion.TemplateID, - WorkspaceID: uuid.NullUUID{}, - AdditionalParameterValues: input.ParameterValues, - }, nil) - if err != nil { - return nil, failJob(fmt.Sprintf("compute parameters: %s", err)) - } - - // Convert types to their corresponding protobuf types. 
- protoParameters, err := convertComputedParameterValues(parameters) - if err != nil { - return nil, failJob(fmt.Sprintf("convert computed parameters to protobuf: %s", err)) - } - - protoJob.Type = &proto.AcquiredJob_TemplateDryRun_{ - TemplateDryRun: &proto.AcquiredJob_TemplateDryRun{ - ParameterValues: protoParameters, - Metadata: &sdkproto.Provision_Metadata{ - CoderUrl: server.AccessURL.String(), - WorkspaceName: input.WorkspaceName, - }, - }, - } - case database.ProvisionerJobTypeTemplateVersionImport: - protoJob.Type = &proto.AcquiredJob_TemplateImport_{ - TemplateImport: &proto.AcquiredJob_TemplateImport{ - Metadata: &sdkproto.Provision_Metadata{ - CoderUrl: server.AccessURL.String(), - }, - }, - } - } - switch job.StorageMethod { - case database.ProvisionerStorageMethodFile: - file, err := server.Database.GetFileByID(ctx, job.FileID) - if err != nil { - return nil, failJob(fmt.Sprintf("get file by hash: %s", err)) - } - protoJob.TemplateSourceArchive = file.Data - default: - return nil, failJob(fmt.Sprintf("unsupported storage method: %s", job.StorageMethod)) - } - if protobuf.Size(protoJob) > provisionersdk.MaxMessageSize { - return nil, failJob(fmt.Sprintf("payload was too big: %d > %d", protobuf.Size(protoJob), provisionersdk.MaxMessageSize)) - } - - return protoJob, err -} - -func (server *provisionerdServer) UpdateJob(ctx context.Context, request *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { - parsedID, err := uuid.Parse(request.JobId) - if err != nil { - return nil, xerrors.Errorf("parse job id: %w", err) - } - server.Logger.Debug(ctx, "UpdateJob starting", slog.F("job_id", parsedID)) - job, err := server.Database.GetProvisionerJobByID(ctx, parsedID) - if err != nil { - return nil, xerrors.Errorf("get job: %w", err) - } - if !job.WorkerID.Valid { - return nil, xerrors.New("job isn't running yet") - } - if job.WorkerID.UUID.String() != server.ID.String() { - return nil, xerrors.New("you don't own this job") - } - err = 
server.Database.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{ - ID: parsedID, - UpdatedAt: database.Now(), - }) - if err != nil { - return nil, xerrors.Errorf("update job: %w", err) - } - - if len(request.Logs) > 0 { - insertParams := database.InsertProvisionerJobLogsParams{ - JobID: parsedID, - } - for _, log := range request.Logs { - logLevel, err := convertLogLevel(log.Level) - if err != nil { - return nil, xerrors.Errorf("convert log level: %w", err) - } - logSource, err := convertLogSource(log.Source) - if err != nil { - return nil, xerrors.Errorf("convert log source: %w", err) - } - insertParams.CreatedAt = append(insertParams.CreatedAt, time.UnixMilli(log.CreatedAt)) - insertParams.Level = append(insertParams.Level, logLevel) - insertParams.Stage = append(insertParams.Stage, log.Stage) - insertParams.Source = append(insertParams.Source, logSource) - insertParams.Output = append(insertParams.Output, log.Output) - server.Logger.Debug(ctx, "job log", - slog.F("job_id", parsedID), - slog.F("stage", log.Stage), - slog.F("output", log.Output)) - } - logs, err := server.Database.InsertProvisionerJobLogs(context.Background(), insertParams) - if err != nil { - server.Logger.Error(ctx, "failed to insert job logs", slog.F("job_id", parsedID), slog.Error(err)) - return nil, xerrors.Errorf("insert job logs: %w", err) - } - // Publish by the lowest log ID inserted so the - // log stream will fetch everything from that point. 
- lowestID := logs[0].ID - server.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID)) - data, err := json.Marshal(provisionerJobLogsMessage{ - CreatedAfter: lowestID, - }) - if err != nil { - return nil, xerrors.Errorf("marshal: %w", err) - } - err = server.Pubsub.Publish(provisionerJobLogsChannel(parsedID), data) - if err != nil { - server.Logger.Error(ctx, "failed to publish job logs", slog.F("job_id", parsedID), slog.Error(err)) - return nil, xerrors.Errorf("publish job log: %w", err) - } - server.Logger.Debug(ctx, "published job logs", slog.F("job_id", parsedID)) - } - - if len(request.Readme) > 0 { - err := server.Database.UpdateTemplateVersionDescriptionByJobID(ctx, database.UpdateTemplateVersionDescriptionByJobIDParams{ - JobID: job.ID, - Readme: string(request.Readme), - UpdatedAt: database.Now(), - }) - if err != nil { - return nil, xerrors.Errorf("update template version description: %w", err) - } - } - - if len(request.ParameterSchemas) > 0 { - for index, protoParameter := range request.ParameterSchemas { - validationTypeSystem, err := convertValidationTypeSystem(protoParameter.ValidationTypeSystem) - if err != nil { - return nil, xerrors.Errorf("convert validation type system for %q: %w", protoParameter.Name, err) - } - - parameterSchema := database.InsertParameterSchemaParams{ - ID: uuid.New(), - CreatedAt: database.Now(), - JobID: job.ID, - Name: protoParameter.Name, - Description: protoParameter.Description, - RedisplayValue: protoParameter.RedisplayValue, - ValidationError: protoParameter.ValidationError, - ValidationCondition: protoParameter.ValidationCondition, - ValidationValueType: protoParameter.ValidationValueType, - ValidationTypeSystem: validationTypeSystem, - - DefaultSourceScheme: database.ParameterSourceSchemeNone, - DefaultDestinationScheme: database.ParameterDestinationSchemeNone, - - AllowOverrideDestination: protoParameter.AllowOverrideDestination, - AllowOverrideSource: protoParameter.AllowOverrideSource, - - Index: 
int32(index), - } - - // It's possible a parameter doesn't define a default source! - if protoParameter.DefaultSource != nil { - parameterSourceScheme, err := convertParameterSourceScheme(protoParameter.DefaultSource.Scheme) - if err != nil { - return nil, xerrors.Errorf("convert parameter source scheme: %w", err) - } - parameterSchema.DefaultSourceScheme = parameterSourceScheme - parameterSchema.DefaultSourceValue = protoParameter.DefaultSource.Value - } - - // It's possible a parameter doesn't define a default destination! - if protoParameter.DefaultDestination != nil { - parameterDestinationScheme, err := convertParameterDestinationScheme(protoParameter.DefaultDestination.Scheme) - if err != nil { - return nil, xerrors.Errorf("convert parameter destination scheme: %w", err) - } - parameterSchema.DefaultDestinationScheme = parameterDestinationScheme - } - - _, err = server.Database.InsertParameterSchema(ctx, parameterSchema) - if err != nil { - return nil, xerrors.Errorf("insert parameter schema: %w", err) - } - } - - var templateID uuid.NullUUID - if job.Type == database.ProvisionerJobTypeTemplateVersionImport { - templateVersion, err := server.Database.GetTemplateVersionByJobID(ctx, job.ID) - if err != nil { - return nil, xerrors.Errorf("get template version by job id: %w", err) - } - templateID = templateVersion.TemplateID - } - - parameters, err := parameter.Compute(ctx, server.Database, parameter.ComputeScope{ - TemplateImportJobID: job.ID, - TemplateID: templateID, - }, nil) - if err != nil { - return nil, xerrors.Errorf("compute parameters: %w", err) - } - // Convert parameters to the protobuf type. 
- protoParameters := make([]*sdkproto.ParameterValue, 0, len(parameters)) - for _, computedParameter := range parameters { - converted, err := convertComputedParameterValue(computedParameter) - if err != nil { - return nil, xerrors.Errorf("convert parameter: %s", err) - } - protoParameters = append(protoParameters, converted) - } - - return &proto.UpdateJobResponse{ - Canceled: job.CanceledAt.Valid, - ParameterValues: protoParameters, - }, nil - } - - return &proto.UpdateJobResponse{ - Canceled: job.CanceledAt.Valid, - }, nil -} - -func (server *provisionerdServer) FailJob(ctx context.Context, failJob *proto.FailedJob) (*proto.Empty, error) { - jobID, err := uuid.Parse(failJob.JobId) - if err != nil { - return nil, xerrors.Errorf("parse job id: %w", err) - } - server.Logger.Debug(ctx, "FailJob starting", slog.F("job_id", jobID)) - job, err := server.Database.GetProvisionerJobByID(ctx, jobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } - if job.CompletedAt.Valid { - return nil, xerrors.Errorf("job already completed") - } - job.CompletedAt = sql.NullTime{ - Time: database.Now(), - Valid: true, - } - job.Error = sql.NullString{ - String: failJob.Error, - Valid: failJob.Error != "", - } - - err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ - ID: jobID, - CompletedAt: job.CompletedAt, - UpdatedAt: database.Now(), - Error: job.Error, - }) - if err != nil { - return nil, xerrors.Errorf("update provisioner job: %w", err) - } - server.Telemetry.Report(&telemetry.Snapshot{ - ProvisionerJobs: []telemetry.ProvisionerJob{telemetry.ConvertProvisionerJob(job)}, - }) - - switch jobType := failJob.Type.(type) { - case *proto.FailedJob_WorkspaceBuild_: - if jobType.WorkspaceBuild.State == nil { - break - } - var input workspaceProvisionJob - err = json.Unmarshal(job.Input, &input) - if err != nil { - return nil, xerrors.Errorf("unmarshal workspace provision input: %w", err) - } 
- build, err := server.Database.UpdateWorkspaceBuildByID(ctx, database.UpdateWorkspaceBuildByIDParams{ - ID: input.WorkspaceBuildID, - UpdatedAt: database.Now(), - ProvisionerState: jobType.WorkspaceBuild.State, - // We are explicitly not updating deadline here. - }) - if err != nil { - return nil, xerrors.Errorf("update workspace build state: %w", err) - } - err = server.Pubsub.Publish(watchWorkspaceChannel(build.WorkspaceID), []byte{}) - if err != nil { - return nil, xerrors.Errorf("update workspace: %w", err) - } - case *proto.FailedJob_TemplateImport_: - } - - data, err := json.Marshal(provisionerJobLogsMessage{EndOfLogs: true}) - if err != nil { - return nil, xerrors.Errorf("marshal job log: %w", err) - } - err = server.Pubsub.Publish(provisionerJobLogsChannel(jobID), data) - if err != nil { - server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err)) - return nil, xerrors.Errorf("publish end of job logs: %w", err) - } - return &proto.Empty{}, nil -} - -// CompleteJob is triggered by a provision daemon to mark a provisioner job as completed. -func (server *provisionerdServer) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) { - jobID, err := uuid.Parse(completed.JobId) - if err != nil { - return nil, xerrors.Errorf("parse job id: %w", err) - } - server.Logger.Debug(ctx, "CompleteJob starting", slog.F("job_id", jobID)) - job, err := server.Database.GetProvisionerJobByID(ctx, jobID) - if err != nil { - return nil, xerrors.Errorf("get job by id: %w", err) - } - if job.WorkerID.UUID.String() != server.ID.String() { - return nil, xerrors.Errorf("you don't have permission to update this job") - } - - telemetrySnapshot := &telemetry.Snapshot{} - // Items are added to this snapshot as they complete! 
- defer server.Telemetry.Report(telemetrySnapshot) - - switch jobType := completed.Type.(type) { - case *proto.CompletedJob_TemplateImport_: - for transition, resources := range map[database.WorkspaceTransition][]*sdkproto.Resource{ - database.WorkspaceTransitionStart: jobType.TemplateImport.StartResources, - database.WorkspaceTransitionStop: jobType.TemplateImport.StopResources, - } { - for _, resource := range resources { - server.Logger.Info(ctx, "inserting template import job resource", - slog.F("job_id", job.ID.String()), - slog.F("resource_name", resource.Name), - slog.F("resource_type", resource.Type), - slog.F("transition", transition)) - - err = insertWorkspaceResource(ctx, server.Database, jobID, transition, resource, telemetrySnapshot) - if err != nil { - return nil, xerrors.Errorf("insert resource: %w", err) - } - } - } - - err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ - ID: jobID, - UpdatedAt: database.Now(), - CompletedAt: sql.NullTime{ - Time: database.Now(), - Valid: true, - }, - }) - if err != nil { - return nil, xerrors.Errorf("update provisioner job: %w", err) - } - server.Logger.Debug(ctx, "marked import job as completed", slog.F("job_id", jobID)) - if err != nil { - return nil, xerrors.Errorf("complete job: %w", err) - } - case *proto.CompletedJob_WorkspaceBuild_: - var input workspaceProvisionJob - err = json.Unmarshal(job.Input, &input) - if err != nil { - return nil, xerrors.Errorf("unmarshal job data: %w", err) - } - - workspaceBuild, err := server.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID) - if err != nil { - return nil, xerrors.Errorf("get workspace build: %w", err) - } - - err = server.Database.InTx(func(db database.Store) error { - now := database.Now() - var workspaceDeadline time.Time - workspace, err := db.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID) - if err == nil { - if workspace.Ttl.Valid { - workspaceDeadline = 
now.Add(time.Duration(workspace.Ttl.Int64)) - } - } else { - // Huh? Did the workspace get deleted? - // In any case, since this is just for the TTL, try and continue anyway. - server.Logger.Error(ctx, "fetch workspace for build", slog.F("workspace_build_id", workspaceBuild.ID), slog.F("workspace_id", workspaceBuild.WorkspaceID)) - } - err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ - ID: jobID, - UpdatedAt: database.Now(), - CompletedAt: sql.NullTime{ - Time: database.Now(), - Valid: true, - }, - }) - if err != nil { - return xerrors.Errorf("update provisioner job: %w", err) - } - _, err = db.UpdateWorkspaceBuildByID(ctx, database.UpdateWorkspaceBuildByIDParams{ - ID: workspaceBuild.ID, - Deadline: workspaceDeadline, - ProvisionerState: jobType.WorkspaceBuild.State, - UpdatedAt: now, - }) - if err != nil { - return xerrors.Errorf("update workspace build: %w", err) - } - // This could be a bulk insert to improve performance. - for _, protoResource := range jobType.WorkspaceBuild.Resources { - err = insertWorkspaceResource(ctx, db, job.ID, workspaceBuild.Transition, protoResource, telemetrySnapshot) - if err != nil { - return xerrors.Errorf("insert provisioner job: %w", err) - } - } - - if workspaceBuild.Transition != database.WorkspaceTransitionDelete { - // This is for deleting a workspace! 
- return nil - } - - err = db.UpdateWorkspaceDeletedByID(ctx, database.UpdateWorkspaceDeletedByIDParams{ - ID: workspaceBuild.WorkspaceID, - Deleted: true, - }) - if err != nil { - return xerrors.Errorf("update workspace deleted: %w", err) - } - - return nil - }) - if err != nil { - return nil, xerrors.Errorf("complete job: %w", err) - } - - err = server.Pubsub.Publish(watchWorkspaceChannel(workspaceBuild.WorkspaceID), []byte{}) - if err != nil { - return nil, xerrors.Errorf("update workspace: %w", err) - } - case *proto.CompletedJob_TemplateDryRun_: - for _, resource := range jobType.TemplateDryRun.Resources { - server.Logger.Info(ctx, "inserting template dry-run job resource", - slog.F("job_id", job.ID.String()), - slog.F("resource_name", resource.Name), - slog.F("resource_type", resource.Type)) - - err = insertWorkspaceResource(ctx, server.Database, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot) - if err != nil { - return nil, xerrors.Errorf("insert resource: %w", err) - } - } - - err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ - ID: jobID, - UpdatedAt: database.Now(), - CompletedAt: sql.NullTime{ - Time: database.Now(), - Valid: true, - }, - }) - if err != nil { - return nil, xerrors.Errorf("update provisioner job: %w", err) - } - server.Logger.Debug(ctx, "marked template dry-run job as completed", slog.F("job_id", jobID)) - if err != nil { - return nil, xerrors.Errorf("complete job: %w", err) - } - - default: - return nil, xerrors.Errorf("unknown job type %q; ensure coderd and provisionerd versions match", - reflect.TypeOf(completed.Type).String()) - } - - data, err := json.Marshal(provisionerJobLogsMessage{EndOfLogs: true}) - if err != nil { - return nil, xerrors.Errorf("marshal job log: %w", err) - } - err = server.Pubsub.Publish(provisionerJobLogsChannel(jobID), data) - if err != nil { - server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", 
jobID), slog.Error(err)) - return nil, xerrors.Errorf("publish end of job logs: %w", err) - } - - server.Logger.Debug(ctx, "CompleteJob done", slog.F("job_id", jobID)) - return &proto.Empty{}, nil -} - -func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.UUID, transition database.WorkspaceTransition, protoResource *sdkproto.Resource, snapshot *telemetry.Snapshot) error { - resource, err := db.InsertWorkspaceResource(ctx, database.InsertWorkspaceResourceParams{ - ID: uuid.New(), - CreatedAt: database.Now(), - JobID: jobID, - Transition: transition, - Type: protoResource.Type, - Name: protoResource.Name, - Hide: protoResource.Hide, - Icon: protoResource.Icon, - InstanceType: sql.NullString{ - String: protoResource.InstanceType, - Valid: protoResource.InstanceType != "", - }, - }) - if err != nil { - return xerrors.Errorf("insert provisioner job resource %q: %w", protoResource.Name, err) - } - snapshot.WorkspaceResources = append(snapshot.WorkspaceResources, telemetry.ConvertWorkspaceResource(resource)) - - var appSlugs = make(map[string]struct{}) - for _, prAgent := range protoResource.Agents { - var instanceID sql.NullString - if prAgent.GetInstanceId() != "" { - instanceID = sql.NullString{ - String: prAgent.GetInstanceId(), - Valid: true, - } - } - var env pqtype.NullRawMessage - if prAgent.Env != nil { - data, err := json.Marshal(prAgent.Env) - if err != nil { - return xerrors.Errorf("marshal env: %w", err) - } - env = pqtype.NullRawMessage{ - RawMessage: data, - Valid: true, - } - } - authToken := uuid.New() - if prAgent.GetToken() != "" { - authToken, err = uuid.Parse(prAgent.GetToken()) - if err != nil { - return xerrors.Errorf("invalid auth token format; must be uuid: %w", err) - } - } - - agentID := uuid.New() - dbAgent, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{ - ID: agentID, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - ResourceID: resource.ID, - Name: prAgent.Name, - AuthToken: 
authToken, - AuthInstanceID: instanceID, - Architecture: prAgent.Architecture, - EnvironmentVariables: env, - Directory: prAgent.Directory, - OperatingSystem: prAgent.OperatingSystem, - StartupScript: sql.NullString{ - String: prAgent.StartupScript, - Valid: prAgent.StartupScript != "", - }, - }) - if err != nil { - return xerrors.Errorf("insert agent: %w", err) - } - snapshot.WorkspaceAgents = append(snapshot.WorkspaceAgents, telemetry.ConvertWorkspaceAgent(dbAgent)) - - for _, app := range prAgent.Apps { - slug := app.Slug - if slug == "" { - return xerrors.Errorf("app must have a slug or name set") - } - if !provisioner.AppSlugRegex.MatchString(slug) { - return xerrors.Errorf("app slug %q does not match regex %q", slug, provisioner.AppSlugRegex.String()) - } - if _, exists := appSlugs[slug]; exists { - return xerrors.Errorf("duplicate app slug, must be unique per template: %q", slug) - } - appSlugs[slug] = struct{}{} - - health := database.WorkspaceAppHealthDisabled - if app.Healthcheck == nil { - app.Healthcheck = &sdkproto.Healthcheck{} - } - if app.Healthcheck.Url != "" { - health = database.WorkspaceAppHealthInitializing - } - - sharingLevel := database.AppSharingLevelOwner - switch app.SharingLevel { - case sdkproto.AppSharingLevel_AUTHENTICATED: - sharingLevel = database.AppSharingLevelAuthenticated - case sdkproto.AppSharingLevel_PUBLIC: - sharingLevel = database.AppSharingLevelPublic - } - - dbApp, err := db.InsertWorkspaceApp(ctx, database.InsertWorkspaceAppParams{ - ID: uuid.New(), - CreatedAt: database.Now(), - AgentID: dbAgent.ID, - Slug: slug, - DisplayName: app.DisplayName, - Icon: app.Icon, - Command: sql.NullString{ - String: app.Command, - Valid: app.Command != "", - }, - Url: sql.NullString{ - String: app.Url, - Valid: app.Url != "", - }, - Subdomain: app.Subdomain, - SharingLevel: sharingLevel, - HealthcheckUrl: app.Healthcheck.Url, - HealthcheckInterval: app.Healthcheck.Interval, - HealthcheckThreshold: app.Healthcheck.Threshold, - Health: 
health, - }) - if err != nil { - return xerrors.Errorf("insert app: %w", err) - } - snapshot.WorkspaceApps = append(snapshot.WorkspaceApps, telemetry.ConvertWorkspaceApp(dbApp)) - } - } - - for _, metadatum := range protoResource.Metadata { - var value sql.NullString - if !metadatum.IsNull { - value.String = metadatum.Value - value.Valid = true - } - - _, err := db.InsertWorkspaceResourceMetadata(ctx, database.InsertWorkspaceResourceMetadataParams{ - WorkspaceResourceID: resource.ID, - Key: metadatum.Key, - Value: value, - Sensitive: metadatum.Sensitive, - }) - if err != nil { - return xerrors.Errorf("insert metadata: %w", err) - } - } - - return nil -} - -func convertValidationTypeSystem(typeSystem sdkproto.ParameterSchema_TypeSystem) (database.ParameterTypeSystem, error) { - switch typeSystem { - case sdkproto.ParameterSchema_None: - return database.ParameterTypeSystemNone, nil - case sdkproto.ParameterSchema_HCL: - return database.ParameterTypeSystemHCL, nil - default: - return database.ParameterTypeSystem(""), xerrors.Errorf("unknown type system: %d", typeSystem) - } -} - -func convertParameterSourceScheme(sourceScheme sdkproto.ParameterSource_Scheme) (database.ParameterSourceScheme, error) { - switch sourceScheme { - case sdkproto.ParameterSource_DATA: - return database.ParameterSourceSchemeData, nil - default: - return database.ParameterSourceScheme(""), xerrors.Errorf("unknown parameter source scheme: %d", sourceScheme) - } -} - -func convertParameterDestinationScheme(destinationScheme sdkproto.ParameterDestination_Scheme) (database.ParameterDestinationScheme, error) { - switch destinationScheme { - case sdkproto.ParameterDestination_ENVIRONMENT_VARIABLE: - return database.ParameterDestinationSchemeEnvironmentVariable, nil - case sdkproto.ParameterDestination_PROVISIONER_VARIABLE: - return database.ParameterDestinationSchemeProvisionerVariable, nil - default: - return database.ParameterDestinationScheme(""), xerrors.Errorf("unknown parameter destination 
scheme: %d", destinationScheme) - } -} - -func convertLogLevel(logLevel sdkproto.LogLevel) (database.LogLevel, error) { - switch logLevel { - case sdkproto.LogLevel_TRACE: - return database.LogLevelTrace, nil - case sdkproto.LogLevel_DEBUG: - return database.LogLevelDebug, nil - case sdkproto.LogLevel_INFO: - return database.LogLevelInfo, nil - case sdkproto.LogLevel_WARN: - return database.LogLevelWarn, nil - case sdkproto.LogLevel_ERROR: - return database.LogLevelError, nil - default: - return database.LogLevel(""), xerrors.Errorf("unknown log level: %d", logLevel) - } -} - -func convertLogSource(logSource proto.LogSource) (database.LogSource, error) { - switch logSource { - case proto.LogSource_PROVISIONER_DAEMON: - return database.LogSourceProvisionerDaemon, nil - case proto.LogSource_PROVISIONER: - return database.LogSourceProvisioner, nil - default: - return database.LogSource(""), xerrors.Errorf("unknown log source: %d", logSource) - } -} - -func convertComputedParameterValues(parameters []parameter.ComputedValue) ([]*sdkproto.ParameterValue, error) { - protoParameters := make([]*sdkproto.ParameterValue, len(parameters)) - for i, computedParameter := range parameters { - converted, err := convertComputedParameterValue(computedParameter) - if err != nil { - return nil, xerrors.Errorf("convert parameter: %w", err) - } - protoParameters[i] = converted - } - - return protoParameters, nil -} - -func convertComputedParameterValue(param parameter.ComputedValue) (*sdkproto.ParameterValue, error) { - var scheme sdkproto.ParameterDestination_Scheme - switch param.DestinationScheme { - case database.ParameterDestinationSchemeEnvironmentVariable: - scheme = sdkproto.ParameterDestination_ENVIRONMENT_VARIABLE - case database.ParameterDestinationSchemeProvisionerVariable: - scheme = sdkproto.ParameterDestination_PROVISIONER_VARIABLE - default: - return nil, xerrors.Errorf("unrecognized destination scheme: %q", param.DestinationScheme) - } - - return 
&sdkproto.ParameterValue{ - DestinationScheme: scheme, - Name: param.Name, - Value: param.SourceValue, - }, nil -} - -func convertWorkspaceTransition(transition database.WorkspaceTransition) (sdkproto.WorkspaceTransition, error) { - switch transition { - case database.WorkspaceTransitionStart: - return sdkproto.WorkspaceTransition_START, nil - case database.WorkspaceTransitionStop: - return sdkproto.WorkspaceTransition_STOP, nil - case database.WorkspaceTransitionDelete: - return sdkproto.WorkspaceTransition_DESTROY, nil - default: - return 0, xerrors.Errorf("unrecognized transition: %q", transition) - } -} diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go new file mode 100644 index 0000000000000..b8a6ad91386e9 --- /dev/null +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -0,0 +1,932 @@ +package provisionerdserver + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "net/url" + "reflect" + "time" + + "github.com/google/uuid" + "github.com/tabbed/pqtype" + "golang.org/x/xerrors" + protobuf "google.golang.org/protobuf/proto" + + "cdr.dev/slog" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/parameter" + "github.com/coder/coder/coderd/telemetry" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/provisioner" + "github.com/coder/coder/provisionerd/proto" + "github.com/coder/coder/provisionersdk" + sdkproto "github.com/coder/coder/provisionersdk/proto" +) + +type Server struct { + AccessURL *url.URL + ID uuid.UUID + Logger slog.Logger + Provisioners []database.ProvisionerType + Database database.Store + Pubsub database.Pubsub + Telemetry telemetry.Reporter +} + +// AcquireJob queries the database to lock a job. +func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + // This marks the job as locked in the database. 
+ job, err := server.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + StartedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + WorkerID: uuid.NullUUID{ + UUID: server.ID, + Valid: true, + }, + Types: server.Provisioners, + }) + if errors.Is(err, sql.ErrNoRows) { + // The provisioner daemon assumes no jobs are available if + // an empty struct is returned. + return &proto.AcquiredJob{}, nil + } + if err != nil { + return nil, xerrors.Errorf("acquire job: %w", err) + } + server.Logger.Debug(ctx, "locked job from database", slog.F("id", job.ID)) + + // Marks the acquired job as failed with the error message provided. + failJob := func(errorMessage string) error { + err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + ID: job.ID, + CompletedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + Error: sql.NullString{ + String: errorMessage, + Valid: true, + }, + }) + if err != nil { + return xerrors.Errorf("update provisioner job: %w", err) + } + return xerrors.Errorf("request job was invalidated: %s", errorMessage) + } + + user, err := server.Database.GetUserByID(ctx, job.InitiatorID) + if err != nil { + return nil, failJob(fmt.Sprintf("get user: %s", err)) + } + + protoJob := &proto.AcquiredJob{ + JobId: job.ID.String(), + CreatedAt: job.CreatedAt.UnixMilli(), + Provisioner: string(job.Provisioner), + UserName: user.Username, + } + switch job.Type { + case database.ProvisionerJobTypeWorkspaceBuild: + var input WorkspaceProvisionJob + err = json.Unmarshal(job.Input, &input) + if err != nil { + return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err)) + } + workspaceBuild, err := server.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID) + if err != nil { + return nil, failJob(fmt.Sprintf("get workspace build: %s", err)) + } + workspace, err := server.Database.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID) + if err != nil { 
+ return nil, failJob(fmt.Sprintf("get workspace: %s", err)) + } + templateVersion, err := server.Database.GetTemplateVersionByID(ctx, workspaceBuild.TemplateVersionID) + if err != nil { + return nil, failJob(fmt.Sprintf("get template version: %s", err)) + } + template, err := server.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID) + if err != nil { + return nil, failJob(fmt.Sprintf("get template: %s", err)) + } + owner, err := server.Database.GetUserByID(ctx, workspace.OwnerID) + if err != nil { + return nil, failJob(fmt.Sprintf("get owner: %s", err)) + } + err = server.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspace.ID), []byte{}) + if err != nil { + return nil, failJob(fmt.Sprintf("publish workspace update: %s", err)) + } + + // Compute parameters for the workspace to consume. + parameters, err := parameter.Compute(ctx, server.Database, parameter.ComputeScope{ + TemplateImportJobID: templateVersion.JobID, + TemplateID: uuid.NullUUID{ + UUID: template.ID, + Valid: true, + }, + WorkspaceID: uuid.NullUUID{ + UUID: workspace.ID, + Valid: true, + }, + }, nil) + if err != nil { + return nil, failJob(fmt.Sprintf("compute parameters: %s", err)) + } + + // Convert types to their corresponding protobuf types. 
+ protoParameters, err := convertComputedParameterValues(parameters) + if err != nil { + return nil, failJob(fmt.Sprintf("convert computed parameters to protobuf: %s", err)) + } + transition, err := convertWorkspaceTransition(workspaceBuild.Transition) + if err != nil { + return nil, failJob(fmt.Sprintf("convert workspace transition: %s", err)) + } + + protoJob.Type = &proto.AcquiredJob_WorkspaceBuild_{ + WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ + WorkspaceBuildId: workspaceBuild.ID.String(), + WorkspaceName: workspace.Name, + State: workspaceBuild.ProvisionerState, + ParameterValues: protoParameters, + Metadata: &sdkproto.Provision_Metadata{ + CoderUrl: server.AccessURL.String(), + WorkspaceTransition: transition, + WorkspaceName: workspace.Name, + WorkspaceOwner: owner.Username, + WorkspaceOwnerEmail: owner.Email, + WorkspaceId: workspace.ID.String(), + WorkspaceOwnerId: owner.ID.String(), + }, + }, + } + case database.ProvisionerJobTypeTemplateVersionDryRun: + var input TemplateVersionDryRunJob + err = json.Unmarshal(job.Input, &input) + if err != nil { + return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err)) + } + + templateVersion, err := server.Database.GetTemplateVersionByID(ctx, input.TemplateVersionID) + if err != nil { + return nil, failJob(fmt.Sprintf("get template version: %s", err)) + } + + // Compute parameters for the dry-run to consume. + parameters, err := parameter.Compute(ctx, server.Database, parameter.ComputeScope{ + TemplateImportJobID: templateVersion.JobID, + TemplateID: templateVersion.TemplateID, + WorkspaceID: uuid.NullUUID{}, + AdditionalParameterValues: input.ParameterValues, + }, nil) + if err != nil { + return nil, failJob(fmt.Sprintf("compute parameters: %s", err)) + } + + // Convert types to their corresponding protobuf types. 
+ protoParameters, err := convertComputedParameterValues(parameters) + if err != nil { + return nil, failJob(fmt.Sprintf("convert computed parameters to protobuf: %s", err)) + } + + protoJob.Type = &proto.AcquiredJob_TemplateDryRun_{ + TemplateDryRun: &proto.AcquiredJob_TemplateDryRun{ + ParameterValues: protoParameters, + Metadata: &sdkproto.Provision_Metadata{ + CoderUrl: server.AccessURL.String(), + WorkspaceName: input.WorkspaceName, + }, + }, + } + case database.ProvisionerJobTypeTemplateVersionImport: + protoJob.Type = &proto.AcquiredJob_TemplateImport_{ + TemplateImport: &proto.AcquiredJob_TemplateImport{ + Metadata: &sdkproto.Provision_Metadata{ + CoderUrl: server.AccessURL.String(), + }, + }, + } + } + switch job.StorageMethod { + case database.ProvisionerStorageMethodFile: + file, err := server.Database.GetFileByID(ctx, job.FileID) + if err != nil { + return nil, failJob(fmt.Sprintf("get file by hash: %s", err)) + } + protoJob.TemplateSourceArchive = file.Data + default: + return nil, failJob(fmt.Sprintf("unsupported storage method: %s", job.StorageMethod)) + } + if protobuf.Size(protoJob) > provisionersdk.MaxMessageSize { + return nil, failJob(fmt.Sprintf("payload was too big: %d > %d", protobuf.Size(protoJob), provisionersdk.MaxMessageSize)) + } + + return protoJob, err +} + +func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { + parsedID, err := uuid.Parse(request.JobId) + if err != nil { + return nil, xerrors.Errorf("parse job id: %w", err) + } + server.Logger.Debug(ctx, "UpdateJob starting", slog.F("job_id", parsedID)) + job, err := server.Database.GetProvisionerJobByID(ctx, parsedID) + if err != nil { + return nil, xerrors.Errorf("get job: %w", err) + } + if !job.WorkerID.Valid { + return nil, xerrors.New("job isn't running yet") + } + if job.WorkerID.UUID.String() != server.ID.String() { + return nil, xerrors.New("you don't own this job") + } + err = 
server.Database.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{ + ID: parsedID, + UpdatedAt: database.Now(), + }) + if err != nil { + return nil, xerrors.Errorf("update job: %w", err) + } + + if len(request.Logs) > 0 { + insertParams := database.InsertProvisionerJobLogsParams{ + JobID: parsedID, + } + for _, log := range request.Logs { + logLevel, err := convertLogLevel(log.Level) + if err != nil { + return nil, xerrors.Errorf("convert log level: %w", err) + } + logSource, err := convertLogSource(log.Source) + if err != nil { + return nil, xerrors.Errorf("convert log source: %w", err) + } + insertParams.CreatedAt = append(insertParams.CreatedAt, time.UnixMilli(log.CreatedAt)) + insertParams.Level = append(insertParams.Level, logLevel) + insertParams.Stage = append(insertParams.Stage, log.Stage) + insertParams.Source = append(insertParams.Source, logSource) + insertParams.Output = append(insertParams.Output, log.Output) + server.Logger.Debug(ctx, "job log", + slog.F("job_id", parsedID), + slog.F("stage", log.Stage), + slog.F("output", log.Output)) + } + logs, err := server.Database.InsertProvisionerJobLogs(context.Background(), insertParams) + if err != nil { + server.Logger.Error(ctx, "failed to insert job logs", slog.F("job_id", parsedID), slog.Error(err)) + return nil, xerrors.Errorf("insert job logs: %w", err) + } + // Publish by the lowest log ID inserted so the + // log stream will fetch everything from that point. 
+ lowestID := logs[0].ID + server.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID)) + data, err := json.Marshal(ProvisionerJobLogsNotifyMessage{ + CreatedAfter: lowestID, + }) + if err != nil { + return nil, xerrors.Errorf("marshal: %w", err) + } + err = server.Pubsub.Publish(ProvisionerJobLogsNotifyChannel(parsedID), data) + if err != nil { + server.Logger.Error(ctx, "failed to publish job logs", slog.F("job_id", parsedID), slog.Error(err)) + return nil, xerrors.Errorf("publish job log: %w", err) + } + server.Logger.Debug(ctx, "published job logs", slog.F("job_id", parsedID)) + } + + if len(request.Readme) > 0 { + err := server.Database.UpdateTemplateVersionDescriptionByJobID(ctx, database.UpdateTemplateVersionDescriptionByJobIDParams{ + JobID: job.ID, + Readme: string(request.Readme), + UpdatedAt: database.Now(), + }) + if err != nil { + return nil, xerrors.Errorf("update template version description: %w", err) + } + } + + if len(request.ParameterSchemas) > 0 { + for index, protoParameter := range request.ParameterSchemas { + validationTypeSystem, err := convertValidationTypeSystem(protoParameter.ValidationTypeSystem) + if err != nil { + return nil, xerrors.Errorf("convert validation type system for %q: %w", protoParameter.Name, err) + } + + parameterSchema := database.InsertParameterSchemaParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + JobID: job.ID, + Name: protoParameter.Name, + Description: protoParameter.Description, + RedisplayValue: protoParameter.RedisplayValue, + ValidationError: protoParameter.ValidationError, + ValidationCondition: protoParameter.ValidationCondition, + ValidationValueType: protoParameter.ValidationValueType, + ValidationTypeSystem: validationTypeSystem, + + DefaultSourceScheme: database.ParameterSourceSchemeNone, + DefaultDestinationScheme: database.ParameterDestinationSchemeNone, + + AllowOverrideDestination: protoParameter.AllowOverrideDestination, + AllowOverrideSource: protoParameter.AllowOverrideSource, + 
+ Index: int32(index), + } + + // It's possible a parameter doesn't define a default source! + if protoParameter.DefaultSource != nil { + parameterSourceScheme, err := convertParameterSourceScheme(protoParameter.DefaultSource.Scheme) + if err != nil { + return nil, xerrors.Errorf("convert parameter source scheme: %w", err) + } + parameterSchema.DefaultSourceScheme = parameterSourceScheme + parameterSchema.DefaultSourceValue = protoParameter.DefaultSource.Value + } + + // It's possible a parameter doesn't define a default destination! + if protoParameter.DefaultDestination != nil { + parameterDestinationScheme, err := convertParameterDestinationScheme(protoParameter.DefaultDestination.Scheme) + if err != nil { + return nil, xerrors.Errorf("convert parameter destination scheme: %w", err) + } + parameterSchema.DefaultDestinationScheme = parameterDestinationScheme + } + + _, err = server.Database.InsertParameterSchema(ctx, parameterSchema) + if err != nil { + return nil, xerrors.Errorf("insert parameter schema: %w", err) + } + } + + var templateID uuid.NullUUID + if job.Type == database.ProvisionerJobTypeTemplateVersionImport { + templateVersion, err := server.Database.GetTemplateVersionByJobID(ctx, job.ID) + if err != nil { + return nil, xerrors.Errorf("get template version by job id: %w", err) + } + templateID = templateVersion.TemplateID + } + + parameters, err := parameter.Compute(ctx, server.Database, parameter.ComputeScope{ + TemplateImportJobID: job.ID, + TemplateID: templateID, + }, nil) + if err != nil { + return nil, xerrors.Errorf("compute parameters: %w", err) + } + // Convert parameters to the protobuf type. 
+ protoParameters := make([]*sdkproto.ParameterValue, 0, len(parameters)) + for _, computedParameter := range parameters { + converted, err := convertComputedParameterValue(computedParameter) + if err != nil { + return nil, xerrors.Errorf("convert parameter: %s", err) + } + protoParameters = append(protoParameters, converted) + } + + return &proto.UpdateJobResponse{ + Canceled: job.CanceledAt.Valid, + ParameterValues: protoParameters, + }, nil + } + + return &proto.UpdateJobResponse{ + Canceled: job.CanceledAt.Valid, + }, nil +} + +func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*proto.Empty, error) { + jobID, err := uuid.Parse(failJob.JobId) + if err != nil { + return nil, xerrors.Errorf("parse job id: %w", err) + } + server.Logger.Debug(ctx, "FailJob starting", slog.F("job_id", jobID)) + job, err := server.Database.GetProvisionerJobByID(ctx, jobID) + if err != nil { + return nil, xerrors.Errorf("get provisioner job: %w", err) + } + if job.CompletedAt.Valid { + return nil, xerrors.Errorf("job already completed") + } + job.CompletedAt = sql.NullTime{ + Time: database.Now(), + Valid: true, + } + job.Error = sql.NullString{ + String: failJob.Error, + Valid: failJob.Error != "", + } + + err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + ID: jobID, + CompletedAt: job.CompletedAt, + UpdatedAt: database.Now(), + Error: job.Error, + }) + if err != nil { + return nil, xerrors.Errorf("update provisioner job: %w", err) + } + server.Telemetry.Report(&telemetry.Snapshot{ + ProvisionerJobs: []telemetry.ProvisionerJob{telemetry.ConvertProvisionerJob(job)}, + }) + + switch jobType := failJob.Type.(type) { + case *proto.FailedJob_WorkspaceBuild_: + if jobType.WorkspaceBuild.State == nil { + break + } + var input WorkspaceProvisionJob + err = json.Unmarshal(job.Input, &input) + if err != nil { + return nil, xerrors.Errorf("unmarshal workspace provision input: %w", err) + } + build, err 
:= server.Database.UpdateWorkspaceBuildByID(ctx, database.UpdateWorkspaceBuildByIDParams{ + ID: input.WorkspaceBuildID, + UpdatedAt: database.Now(), + ProvisionerState: jobType.WorkspaceBuild.State, + // We are explicitly not updating deadline here. + }) + if err != nil { + return nil, xerrors.Errorf("update workspace build state: %w", err) + } + err = server.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), []byte{}) + if err != nil { + return nil, xerrors.Errorf("update workspace: %w", err) + } + case *proto.FailedJob_TemplateImport_: + } + + data, err := json.Marshal(ProvisionerJobLogsNotifyMessage{EndOfLogs: true}) + if err != nil { + return nil, xerrors.Errorf("marshal job log: %w", err) + } + err = server.Pubsub.Publish(ProvisionerJobLogsNotifyChannel(jobID), data) + if err != nil { + server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err)) + return nil, xerrors.Errorf("publish end of job logs: %w", err) + } + return &proto.Empty{}, nil +} + +// CompleteJob is triggered by a provision daemon to mark a provisioner job as completed. +func (server *Server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) { + jobID, err := uuid.Parse(completed.JobId) + if err != nil { + return nil, xerrors.Errorf("parse job id: %w", err) + } + server.Logger.Debug(ctx, "CompleteJob starting", slog.F("job_id", jobID)) + job, err := server.Database.GetProvisionerJobByID(ctx, jobID) + if err != nil { + return nil, xerrors.Errorf("get job by id: %w", err) + } + if job.WorkerID.UUID.String() != server.ID.String() { + return nil, xerrors.Errorf("you don't have permission to update this job") + } + + telemetrySnapshot := &telemetry.Snapshot{} + // Items are added to this snapshot as they complete! 
+ defer server.Telemetry.Report(telemetrySnapshot) + + switch jobType := completed.Type.(type) { + case *proto.CompletedJob_TemplateImport_: + for transition, resources := range map[database.WorkspaceTransition][]*sdkproto.Resource{ + database.WorkspaceTransitionStart: jobType.TemplateImport.StartResources, + database.WorkspaceTransitionStop: jobType.TemplateImport.StopResources, + } { + for _, resource := range resources { + server.Logger.Info(ctx, "inserting template import job resource", + slog.F("job_id", job.ID.String()), + slog.F("resource_name", resource.Name), + slog.F("resource_type", resource.Type), + slog.F("transition", transition)) + + err = insertWorkspaceResource(ctx, server.Database, jobID, transition, resource, telemetrySnapshot) + if err != nil { + return nil, xerrors.Errorf("insert resource: %w", err) + } + } + } + + err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + ID: jobID, + UpdatedAt: database.Now(), + CompletedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + }) + if err != nil { + return nil, xerrors.Errorf("update provisioner job: %w", err) + } + server.Logger.Debug(ctx, "marked import job as completed", slog.F("job_id", jobID)) + if err != nil { + return nil, xerrors.Errorf("complete job: %w", err) + } + case *proto.CompletedJob_WorkspaceBuild_: + var input WorkspaceProvisionJob + err = json.Unmarshal(job.Input, &input) + if err != nil { + return nil, xerrors.Errorf("unmarshal job data: %w", err) + } + + workspaceBuild, err := server.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID) + if err != nil { + return nil, xerrors.Errorf("get workspace build: %w", err) + } + + err = server.Database.InTx(func(db database.Store) error { + now := database.Now() + var workspaceDeadline time.Time + workspace, err := db.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID) + if err == nil { + if workspace.Ttl.Valid { + workspaceDeadline = 
now.Add(time.Duration(workspace.Ttl.Int64)) + } + } else { + // Huh? Did the workspace get deleted? + // In any case, since this is just for the TTL, try and continue anyway. + server.Logger.Error(ctx, "fetch workspace for build", slog.F("workspace_build_id", workspaceBuild.ID), slog.F("workspace_id", workspaceBuild.WorkspaceID)) + } + err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + ID: jobID, + UpdatedAt: database.Now(), + CompletedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + }) + if err != nil { + return xerrors.Errorf("update provisioner job: %w", err) + } + _, err = db.UpdateWorkspaceBuildByID(ctx, database.UpdateWorkspaceBuildByIDParams{ + ID: workspaceBuild.ID, + Deadline: workspaceDeadline, + ProvisionerState: jobType.WorkspaceBuild.State, + UpdatedAt: now, + }) + if err != nil { + return xerrors.Errorf("update workspace build: %w", err) + } + // This could be a bulk insert to improve performance. + for _, protoResource := range jobType.WorkspaceBuild.Resources { + err = insertWorkspaceResource(ctx, db, job.ID, workspaceBuild.Transition, protoResource, telemetrySnapshot) + if err != nil { + return xerrors.Errorf("insert provisioner job: %w", err) + } + } + + if workspaceBuild.Transition != database.WorkspaceTransitionDelete { + // This is for deleting a workspace! 
+ return nil + } + + err = db.UpdateWorkspaceDeletedByID(ctx, database.UpdateWorkspaceDeletedByIDParams{ + ID: workspaceBuild.WorkspaceID, + Deleted: true, + }) + if err != nil { + return xerrors.Errorf("update workspace deleted: %w", err) + } + + return nil + }) + if err != nil { + return nil, xerrors.Errorf("complete job: %w", err) + } + + err = server.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{}) + if err != nil { + return nil, xerrors.Errorf("update workspace: %w", err) + } + case *proto.CompletedJob_TemplateDryRun_: + for _, resource := range jobType.TemplateDryRun.Resources { + server.Logger.Info(ctx, "inserting template dry-run job resource", + slog.F("job_id", job.ID.String()), + slog.F("resource_name", resource.Name), + slog.F("resource_type", resource.Type)) + + err = insertWorkspaceResource(ctx, server.Database, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot) + if err != nil { + return nil, xerrors.Errorf("insert resource: %w", err) + } + } + + err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + ID: jobID, + UpdatedAt: database.Now(), + CompletedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + }) + if err != nil { + return nil, xerrors.Errorf("update provisioner job: %w", err) + } + server.Logger.Debug(ctx, "marked template dry-run job as completed", slog.F("job_id", jobID)) + if err != nil { + return nil, xerrors.Errorf("complete job: %w", err) + } + + default: + return nil, xerrors.Errorf("unknown job type %q; ensure coderd and provisionerd versions match", + reflect.TypeOf(completed.Type).String()) + } + + data, err := json.Marshal(ProvisionerJobLogsNotifyMessage{EndOfLogs: true}) + if err != nil { + return nil, xerrors.Errorf("marshal job log: %w", err) + } + err = server.Pubsub.Publish(ProvisionerJobLogsNotifyChannel(jobID), data) + if err != nil { + server.Logger.Error(ctx, "failed to publish end of job 
logs", slog.F("job_id", jobID), slog.Error(err)) + return nil, xerrors.Errorf("publish end of job logs: %w", err) + } + + server.Logger.Debug(ctx, "CompleteJob done", slog.F("job_id", jobID)) + return &proto.Empty{}, nil +} + +func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.UUID, transition database.WorkspaceTransition, protoResource *sdkproto.Resource, snapshot *telemetry.Snapshot) error { + resource, err := db.InsertWorkspaceResource(ctx, database.InsertWorkspaceResourceParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + JobID: jobID, + Transition: transition, + Type: protoResource.Type, + Name: protoResource.Name, + Hide: protoResource.Hide, + Icon: protoResource.Icon, + InstanceType: sql.NullString{ + String: protoResource.InstanceType, + Valid: protoResource.InstanceType != "", + }, + }) + if err != nil { + return xerrors.Errorf("insert provisioner job resource %q: %w", protoResource.Name, err) + } + snapshot.WorkspaceResources = append(snapshot.WorkspaceResources, telemetry.ConvertWorkspaceResource(resource)) + + var appSlugs = make(map[string]struct{}) + for _, prAgent := range protoResource.Agents { + var instanceID sql.NullString + if prAgent.GetInstanceId() != "" { + instanceID = sql.NullString{ + String: prAgent.GetInstanceId(), + Valid: true, + } + } + var env pqtype.NullRawMessage + if prAgent.Env != nil { + data, err := json.Marshal(prAgent.Env) + if err != nil { + return xerrors.Errorf("marshal env: %w", err) + } + env = pqtype.NullRawMessage{ + RawMessage: data, + Valid: true, + } + } + authToken := uuid.New() + if prAgent.GetToken() != "" { + authToken, err = uuid.Parse(prAgent.GetToken()) + if err != nil { + return xerrors.Errorf("invalid auth token format; must be uuid: %w", err) + } + } + + agentID := uuid.New() + dbAgent, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{ + ID: agentID, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + ResourceID: resource.ID, + Name: 
prAgent.Name, + AuthToken: authToken, + AuthInstanceID: instanceID, + Architecture: prAgent.Architecture, + EnvironmentVariables: env, + Directory: prAgent.Directory, + OperatingSystem: prAgent.OperatingSystem, + StartupScript: sql.NullString{ + String: prAgent.StartupScript, + Valid: prAgent.StartupScript != "", + }, + }) + if err != nil { + return xerrors.Errorf("insert agent: %w", err) + } + snapshot.WorkspaceAgents = append(snapshot.WorkspaceAgents, telemetry.ConvertWorkspaceAgent(dbAgent)) + + for _, app := range prAgent.Apps { + slug := app.Slug + if slug == "" { + return xerrors.Errorf("app must have a slug or name set") + } + if !provisioner.AppSlugRegex.MatchString(slug) { + return xerrors.Errorf("app slug %q does not match regex %q", slug, provisioner.AppSlugRegex.String()) + } + if _, exists := appSlugs[slug]; exists { + return xerrors.Errorf("duplicate app slug, must be unique per template: %q", slug) + } + appSlugs[slug] = struct{}{} + + health := database.WorkspaceAppHealthDisabled + if app.Healthcheck == nil { + app.Healthcheck = &sdkproto.Healthcheck{} + } + if app.Healthcheck.Url != "" { + health = database.WorkspaceAppHealthInitializing + } + + sharingLevel := database.AppSharingLevelOwner + switch app.SharingLevel { + case sdkproto.AppSharingLevel_AUTHENTICATED: + sharingLevel = database.AppSharingLevelAuthenticated + case sdkproto.AppSharingLevel_PUBLIC: + sharingLevel = database.AppSharingLevelPublic + } + + dbApp, err := db.InsertWorkspaceApp(ctx, database.InsertWorkspaceAppParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + AgentID: dbAgent.ID, + Slug: slug, + DisplayName: app.DisplayName, + Icon: app.Icon, + Command: sql.NullString{ + String: app.Command, + Valid: app.Command != "", + }, + Url: sql.NullString{ + String: app.Url, + Valid: app.Url != "", + }, + Subdomain: app.Subdomain, + SharingLevel: sharingLevel, + HealthcheckUrl: app.Healthcheck.Url, + HealthcheckInterval: app.Healthcheck.Interval, + HealthcheckThreshold: 
app.Healthcheck.Threshold, + Health: health, + }) + if err != nil { + return xerrors.Errorf("insert app: %w", err) + } + snapshot.WorkspaceApps = append(snapshot.WorkspaceApps, telemetry.ConvertWorkspaceApp(dbApp)) + } + } + + for _, metadatum := range protoResource.Metadata { + var value sql.NullString + if !metadatum.IsNull { + value.String = metadatum.Value + value.Valid = true + } + + _, err := db.InsertWorkspaceResourceMetadata(ctx, database.InsertWorkspaceResourceMetadataParams{ + WorkspaceResourceID: resource.ID, + Key: metadatum.Key, + Value: value, + Sensitive: metadatum.Sensitive, + }) + if err != nil { + return xerrors.Errorf("insert metadata: %w", err) + } + } + + return nil +} + +func convertValidationTypeSystem(typeSystem sdkproto.ParameterSchema_TypeSystem) (database.ParameterTypeSystem, error) { + switch typeSystem { + case sdkproto.ParameterSchema_None: + return database.ParameterTypeSystemNone, nil + case sdkproto.ParameterSchema_HCL: + return database.ParameterTypeSystemHCL, nil + default: + return database.ParameterTypeSystem(""), xerrors.Errorf("unknown type system: %d", typeSystem) + } +} + +func convertParameterSourceScheme(sourceScheme sdkproto.ParameterSource_Scheme) (database.ParameterSourceScheme, error) { + switch sourceScheme { + case sdkproto.ParameterSource_DATA: + return database.ParameterSourceSchemeData, nil + default: + return database.ParameterSourceScheme(""), xerrors.Errorf("unknown parameter source scheme: %d", sourceScheme) + } +} + +func convertParameterDestinationScheme(destinationScheme sdkproto.ParameterDestination_Scheme) (database.ParameterDestinationScheme, error) { + switch destinationScheme { + case sdkproto.ParameterDestination_ENVIRONMENT_VARIABLE: + return database.ParameterDestinationSchemeEnvironmentVariable, nil + case sdkproto.ParameterDestination_PROVISIONER_VARIABLE: + return database.ParameterDestinationSchemeProvisionerVariable, nil + default: + return database.ParameterDestinationScheme(""), 
xerrors.Errorf("unknown parameter destination scheme: %d", destinationScheme) + } +} + +func convertLogLevel(logLevel sdkproto.LogLevel) (database.LogLevel, error) { + switch logLevel { + case sdkproto.LogLevel_TRACE: + return database.LogLevelTrace, nil + case sdkproto.LogLevel_DEBUG: + return database.LogLevelDebug, nil + case sdkproto.LogLevel_INFO: + return database.LogLevelInfo, nil + case sdkproto.LogLevel_WARN: + return database.LogLevelWarn, nil + case sdkproto.LogLevel_ERROR: + return database.LogLevelError, nil + default: + return database.LogLevel(""), xerrors.Errorf("unknown log level: %d", logLevel) + } +} + +func convertLogSource(logSource proto.LogSource) (database.LogSource, error) { + switch logSource { + case proto.LogSource_PROVISIONER_DAEMON: + return database.LogSourceProvisionerDaemon, nil + case proto.LogSource_PROVISIONER: + return database.LogSourceProvisioner, nil + default: + return database.LogSource(""), xerrors.Errorf("unknown log source: %d", logSource) + } +} + +func convertComputedParameterValues(parameters []parameter.ComputedValue) ([]*sdkproto.ParameterValue, error) { + protoParameters := make([]*sdkproto.ParameterValue, len(parameters)) + for i, computedParameter := range parameters { + converted, err := convertComputedParameterValue(computedParameter) + if err != nil { + return nil, xerrors.Errorf("convert parameter: %w", err) + } + protoParameters[i] = converted + } + + return protoParameters, nil +} + +func convertComputedParameterValue(param parameter.ComputedValue) (*sdkproto.ParameterValue, error) { + var scheme sdkproto.ParameterDestination_Scheme + switch param.DestinationScheme { + case database.ParameterDestinationSchemeEnvironmentVariable: + scheme = sdkproto.ParameterDestination_ENVIRONMENT_VARIABLE + case database.ParameterDestinationSchemeProvisionerVariable: + scheme = sdkproto.ParameterDestination_PROVISIONER_VARIABLE + default: + return nil, xerrors.Errorf("unrecognized destination scheme: %q", 
param.DestinationScheme) + } + + return &sdkproto.ParameterValue{ + DestinationScheme: scheme, + Name: param.Name, + Value: param.SourceValue, + }, nil +} + +func convertWorkspaceTransition(transition database.WorkspaceTransition) (sdkproto.WorkspaceTransition, error) { + switch transition { + case database.WorkspaceTransitionStart: + return sdkproto.WorkspaceTransition_START, nil + case database.WorkspaceTransitionStop: + return sdkproto.WorkspaceTransition_STOP, nil + case database.WorkspaceTransitionDelete: + return sdkproto.WorkspaceTransition_DESTROY, nil + default: + return 0, xerrors.Errorf("unrecognized transition: %q", transition) + } +} + +// WorkspaceProvisionJob is the payload for the "workspace_provision" job type. +type WorkspaceProvisionJob struct { + WorkspaceBuildID uuid.UUID `json:"workspace_build_id"` + DryRun bool `json:"dry_run"` +} + +// TemplateVersionDryRunJob is the payload for the "template_version_dry_run" job type. +type TemplateVersionDryRunJob struct { + TemplateVersionID uuid.UUID `json:"template_version_id"` + WorkspaceName string `json:"workspace_name"` + ParameterValues []database.ParameterValue `json:"parameter_values"` +} + +// ProvisionerJobLogsNotifyMessage is the payload published on +// the provisioner job logs notify channel. +type ProvisionerJobLogsNotifyMessage struct { + CreatedAfter int64 `json:"created_after"` + EndOfLogs bool `json:"end_of_logs,omitempty"` +} + +// ProvisionerJobLogsNotifyChannel is the PostgreSQL NOTIFY channel +// to publish updates to job logs on. 
+func ProvisionerJobLogsNotifyChannel(jobID uuid.UUID) string { + return fmt.Sprintf("provisioner-log-logs:%s", jobID) +} diff --git a/coderd/templateversions.go b/coderd/templateversions.go index bc0a3c91bf7fe..1d3ccb7919a23 100644 --- a/coderd/templateversions.go +++ b/coderd/templateversions.go @@ -17,6 +17,7 @@ import ( "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/parameter" + "github.com/coder/coder/coderd/provisionerdserver" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/codersdk" ) @@ -261,7 +262,7 @@ func (api *API) postTemplateVersionDryRun(rw http.ResponseWriter, r *http.Reques // Marshal template version dry-run job with the parameters from the // request. - input, err := json.Marshal(templateVersionDryRunJob{ + input, err := json.Marshal(provisionerdserver.TemplateVersionDryRunJob{ TemplateVersionID: templateVersion.ID, WorkspaceName: req.WorkspaceName, ParameterValues: parameterValues, @@ -428,7 +429,7 @@ func (api *API) fetchTemplateVersionDryRunJob(rw http.ResponseWriter, r *http.Re } // Verify that the template version is the one used in the request. 
- var input templateVersionDryRunJob + var input provisionerdserver.TemplateVersionDryRunJob err = json.Unmarshal(job.Input, &input) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ diff --git a/coderd/workspacebuilds.go b/coderd/workspacebuilds.go index 2373556cb8417..54f8d2a9affaa 100644 --- a/coderd/workspacebuilds.go +++ b/coderd/workspacebuilds.go @@ -20,6 +20,7 @@ import ( "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" + "github.com/coder/coder/coderd/provisionerdserver" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/codersdk" ) @@ -495,7 +496,7 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) { } workspaceBuildID := uuid.New() - input, err := json.Marshal(workspaceProvisionJob{ + input, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{ WorkspaceBuildID: workspaceBuildID, }) if err != nil { diff --git a/coderd/workspaceresourceauth.go b/coderd/workspaceresourceauth.go index 3d8fc0b281fd5..84923fd33db00 100644 --- a/coderd/workspaceresourceauth.go +++ b/coderd/workspaceresourceauth.go @@ -11,6 +11,7 @@ import ( "github.com/coder/coder/coderd/azureidentity" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/provisionerdserver" "github.com/coder/coder/codersdk" "github.com/mitchellh/mapstructure" @@ -130,7 +131,7 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in }) return } - var jobData workspaceProvisionJob + var jobData provisionerdserver.WorkspaceProvisionJob err = json.Unmarshal(job.Input, &jobData) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ diff --git a/coderd/workspaces.go b/coderd/workspaces.go index 53e56be63fda4..2a18991e0fefc 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -25,6 +25,7 @@ import ( 
"github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" + "github.com/coder/coder/coderd/provisionerdserver" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/coderd/tracing" @@ -472,7 +473,7 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req } } - input, err := json.Marshal(workspaceProvisionJob{ + input, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{ WorkspaceBuildID: workspaceBuildID, }) if err != nil { diff --git a/codersdk/workspaces.go b/codersdk/workspaces.go index bf8279a016d87..6e217feed4531 100644 --- a/codersdk/workspaces.go +++ b/codersdk/workspaces.go @@ -397,3 +397,10 @@ func (c *Client) GetAppHost(ctx context.Context) (GetAppHostResponse, error) { var host GetAppHostResponse return host, json.NewDecoder(res.Body).Decode(&host) } + +// WorkspaceNotifyChannel is the PostgreSQL NOTIFY +// channel to listen for updates on. The payload is empty, +// because the size of a workspace payload can be very large. 
+func WorkspaceNotifyChannel(id uuid.UUID) string { + return fmt.Sprintf("workspace:%s", id) +} From a59a7ed166c00e34a1a402d88fe57c63c4466bc0 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 7 Nov 2022 23:04:13 +0000 Subject: [PATCH 2/5] Add tests --- .../provisionerdserver/provisionerdserver.go | 3 + .../provisionerdserver_test.go | 460 ++++++++++++++++++ 2 files changed, 463 insertions(+) create mode 100644 coderd/provisionerdserver/provisionerdserver_test.go diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index b8a6ad91386e9..b7dfe0360fe39 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -411,6 +411,9 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p if err != nil { return nil, xerrors.Errorf("get provisioner job: %w", err) } + if job.WorkerID.UUID.String() != server.ID.String() { + return nil, xerrors.New("you don't own this job") + } if job.CompletedAt.Valid { return nil, xerrors.Errorf("job already completed") } diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go new file mode 100644 index 0000000000000..16fa17344f3ed --- /dev/null +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -0,0 +1,460 @@ +package provisionerdserver_test + +import ( + "context" + "database/sql" + "encoding/json" + "net/url" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbtestutil" + "github.com/coder/coder/coderd/provisionerdserver" + "github.com/coder/coder/coderd/telemetry" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/provisionerd/proto" + sdkproto "github.com/coder/coder/provisionersdk/proto" +) + +func TestAcquireJob(t *testing.T) { + t.Parallel() + 
t.Run("NoJobs", func(t *testing.T) { + t.Parallel() + srv := setup(t) + job, err := srv.AcquireJob(context.Background(), nil) + require.NoError(t, err) + require.Equal(t, &proto.AcquiredJob{}, job) + }) + t.Run("InitiatorNotFound", func(t *testing.T) { + t.Parallel() + srv := setup(t) + _, err := srv.Database.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ + ID: uuid.New(), + InitiatorID: uuid.New(), + Provisioner: database.ProvisionerTypeEcho, + }) + require.NoError(t, err) + _, err = srv.AcquireJob(context.Background(), nil) + require.ErrorContains(t, err, "sql: no rows in result set") + }) + t.Run("WorkspaceBuildJob", func(t *testing.T) { + t.Parallel() + srv := setup(t) + ctx := context.Background() + user, err := srv.Database.InsertUser(context.Background(), database.InsertUserParams{ + ID: uuid.New(), + Username: "testing", + }) + require.NoError(t, err) + template, err := srv.Database.InsertTemplate(ctx, database.InsertTemplateParams{ + ID: uuid.New(), + Name: "template", + }) + require.NoError(t, err) + version, err := srv.Database.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ + ID: uuid.New(), + TemplateID: uuid.NullUUID{ + UUID: template.ID, + Valid: true, + }, + JobID: uuid.New(), + }) + require.NoError(t, err) + workspace, err := srv.Database.InsertWorkspace(ctx, database.InsertWorkspaceParams{ + ID: uuid.New(), + OwnerID: user.ID, + TemplateID: template.ID, + Name: "workspace", + }) + require.NoError(t, err) + build, err := srv.Database.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ + ID: uuid.New(), + WorkspaceID: workspace.ID, + BuildNumber: 1, + JobID: uuid.New(), + TemplateVersionID: version.ID, + Transition: database.WorkspaceTransitionStart, + }) + require.NoError(t, err) + + data, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{ + WorkspaceBuildID: build.ID, + }) + require.NoError(t, err) + + file, err := srv.Database.InsertFile(ctx, database.InsertFileParams{ + ID: 
uuid.New(), + Hash: "something", + Data: []byte{}, + }) + require.NoError(t, err) + + _, err = srv.Database.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ + ID: build.JobID, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + OrganizationID: uuid.New(), + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeWorkspaceBuild, + Input: data, + }) + require.NoError(t, err) + + published := make(chan struct{}) + closeSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) { + close(published) + }) + require.NoError(t, err) + defer closeSubscribe() + + job, err := srv.AcquireJob(ctx, nil) + require.NoError(t, err) + + <-published + + got, err := json.Marshal(job.Type) + require.NoError(t, err) + + want, err := json.Marshal(&proto.AcquiredJob_WorkspaceBuild_{ + WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ + WorkspaceBuildId: build.ID.String(), + WorkspaceName: workspace.Name, + ParameterValues: []*sdkproto.ParameterValue{}, + Metadata: &sdkproto.Provision_Metadata{ + CoderUrl: srv.AccessURL.String(), + WorkspaceTransition: sdkproto.WorkspaceTransition_START, + WorkspaceName: workspace.Name, + WorkspaceOwner: user.Username, + WorkspaceOwnerEmail: user.Email, + WorkspaceId: workspace.ID.String(), + WorkspaceOwnerId: user.ID.String(), + }, + }, + }) + require.NoError(t, err) + + require.JSONEq(t, string(want), string(got)) + }) + t.Run("TemplateVersionDryRun", func(t *testing.T) { + t.Parallel() + srv := setup(t) + ctx := context.Background() + user, err := srv.Database.InsertUser(ctx, database.InsertUserParams{ + ID: uuid.New(), + Username: "testing", + }) + require.NoError(t, err) + version, err := srv.Database.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ + ID: uuid.New(), + }) + require.NoError(t, err) + + data, err := 
json.Marshal(provisionerdserver.TemplateVersionDryRunJob{ + TemplateVersionID: version.ID, + WorkspaceName: "testing", + ParameterValues: []database.ParameterValue{}, + }) + require.NoError(t, err) + + file, err := srv.Database.InsertFile(ctx, database.InsertFileParams{ + ID: uuid.New(), + Hash: "something", + Data: []byte{}, + }) + require.NoError(t, err) + + _, err = srv.Database.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + OrganizationID: uuid.New(), + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeTemplateVersionDryRun, + Input: data, + }) + require.NoError(t, err) + + job, err := srv.AcquireJob(ctx, nil) + require.NoError(t, err) + + got, err := json.Marshal(job.Type) + require.NoError(t, err) + + want, err := json.Marshal(&proto.AcquiredJob_TemplateDryRun_{ + TemplateDryRun: &proto.AcquiredJob_TemplateDryRun{ + ParameterValues: []*sdkproto.ParameterValue{}, + Metadata: &sdkproto.Provision_Metadata{ + CoderUrl: srv.AccessURL.String(), + WorkspaceName: "testing", + }, + }, + }) + require.NoError(t, err) + require.JSONEq(t, string(want), string(got)) + }) + t.Run("TemplateVersionImport", func(t *testing.T) { + t.Parallel() + srv := setup(t) + ctx := context.Background() + user, err := srv.Database.InsertUser(ctx, database.InsertUserParams{ + ID: uuid.New(), + Username: "testing", + }) + require.NoError(t, err) + + file, err := srv.Database.InsertFile(ctx, database.InsertFileParams{ + ID: uuid.New(), + Hash: "something", + Data: []byte{}, + }) + require.NoError(t, err) + + _, err = srv.Database.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + OrganizationID: uuid.New(), + InitiatorID: user.ID, + Provisioner: 
database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeTemplateVersionImport, + Input: json.RawMessage{}, + }) + require.NoError(t, err) + + job, err := srv.AcquireJob(ctx, nil) + require.NoError(t, err) + + got, err := json.Marshal(job.Type) + require.NoError(t, err) + + want, err := json.Marshal(&proto.AcquiredJob_TemplateImport_{ + TemplateImport: &proto.AcquiredJob_TemplateImport{ + Metadata: &sdkproto.Provision_Metadata{ + CoderUrl: srv.AccessURL.String(), + }, + }, + }) + require.NoError(t, err) + require.JSONEq(t, string(want), string(got)) + }) +} + +func TestUpdateJob(t *testing.T) { + t.Parallel() + ctx := context.Background() + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + srv := setup(t) + _, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{ + JobId: "hello", + }) + require.ErrorContains(t, err, "invalid UUID") + + _, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{ + JobId: uuid.NewString(), + }) + require.ErrorContains(t, err, "no rows in result set") + }) + t.Run("NotRunning", func(t *testing.T) { + t.Parallel() + srv := setup(t) + job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: uuid.New(), + }) + require.NoError(t, err) + _, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{ + JobId: job.ID.String(), + }) + require.ErrorContains(t, err, "job isn't running yet") + }) + // This test prevents runners from updating jobs they don't own! 
+ t.Run("NotOwner", func(t *testing.T) { + t.Parallel() + srv := setup(t) + job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: uuid.New(), + Provisioner: database.ProvisionerTypeEcho, + }) + require.NoError(t, err) + _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + WorkerID: uuid.NullUUID{ + UUID: uuid.New(), + Valid: true, + }, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + }) + require.NoError(t, err) + _, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{ + JobId: job.ID.String(), + }) + require.ErrorContains(t, err, "you don't own this job") + }) + + setupJob := func(t *testing.T, srv *provisionerdserver.Server) uuid.UUID { + job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: uuid.New(), + Provisioner: database.ProvisionerTypeEcho, + }) + require.NoError(t, err) + _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + WorkerID: uuid.NullUUID{ + UUID: srv.ID, + Valid: true, + }, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + }) + require.NoError(t, err) + return job.ID + } + + t.Run("Success", func(t *testing.T) { + t.Parallel() + srv := setup(t) + job := setupJob(t, srv) + _, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{ + JobId: job.String(), + }) + require.NoError(t, err) + }) + + t.Run("Logs", func(t *testing.T) { + t.Parallel() + srv := setup(t) + job := setupJob(t, srv) + + published := make(chan struct{}) + + closeListener, err := srv.Pubsub.Subscribe(provisionerdserver.ProvisionerJobLogsNotifyChannel(job), func(_ context.Context, _ []byte) { + close(published) + }) + require.NoError(t, err) + defer closeListener() + + _, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{ + JobId: job.String(), + Logs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER, + Level: sdkproto.LogLevel_INFO, + Output: "hi", + }}, + }) + require.NoError(t, err) + + 
<-published + }) + t.Run("Readme", func(t *testing.T) { + t.Parallel() + srv := setup(t) + job := setupJob(t, srv) + version, err := srv.Database.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ + ID: uuid.New(), + JobID: job, + }) + require.NoError(t, err) + _, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{ + JobId: job.String(), + Readme: []byte("# hello world"), + }) + require.NoError(t, err) + + version, err = srv.Database.GetTemplateVersionByID(ctx, version.ID) + require.NoError(t, err) + require.Equal(t, "# hello world", version.Readme) + }) +} + +func TestFailJob(t *testing.T) { + t.Parallel() + ctx := context.Background() + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + srv := setup(t) + _, err := srv.FailJob(ctx, &proto.FailedJob{ + JobId: "hello", + }) + require.ErrorContains(t, err, "invalid UUID") + + _, err = srv.FailJob(ctx, &proto.FailedJob{ + JobId: uuid.NewString(), + }) + require.ErrorContains(t, err, "no rows in result set") + }) + // This test prevents runners from updating jobs they don't own! 
+ t.Run("NotOwner", func(t *testing.T) { + t.Parallel() + srv := setup(t) + job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: uuid.New(), + Provisioner: database.ProvisionerTypeEcho, + }) + require.NoError(t, err) + _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + WorkerID: uuid.NullUUID{ + UUID: uuid.New(), + Valid: true, + }, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + }) + require.NoError(t, err) + _, err = srv.FailJob(ctx, &proto.FailedJob{ + JobId: job.ID.String(), + }) + require.ErrorContains(t, err, "you don't own this job") + }) + t.Run("AlreadyCompleted", func(t *testing.T) { + t.Parallel() + srv := setup(t) + job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: uuid.New(), + Provisioner: database.ProvisionerTypeEcho, + }) + require.NoError(t, err) + _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + WorkerID: uuid.NullUUID{ + UUID: srv.ID, + Valid: true, + }, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + }) + require.NoError(t, err) + err = srv.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + ID: job.ID, + CompletedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + }) + require.NoError(t, err) + _, err = srv.FailJob(ctx, &proto.FailedJob{ + JobId: job.ID.String(), + }) + require.ErrorContains(t, err, "job already completed") + }) +} + +func setup(t *testing.T) *provisionerdserver.Server { + t.Helper() + db, pubsub := dbtestutil.NewDB(t) + + return &provisionerdserver.Server{ + ID: uuid.New(), + Logger: slogtest.Make(t, nil), + AccessURL: &url.URL{}, + Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho}, + Database: db, + Pubsub: pubsub, + Telemetry: telemetry.NewNoop(), + } +} From dfa73b3d0a9e8b94bb5ea323d6f06ff5c32e8b11 Mon Sep 17 00:00:00 2001 From: Kyle 
Carberry Date: Mon, 7 Nov 2022 23:13:34 +0000 Subject: [PATCH 3/5] Add workspace builds --- .../provisionerdserver_test.go | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index 16fa17344f3ed..9bd4d98959475 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -442,6 +442,60 @@ func TestFailJob(t *testing.T) { }) require.ErrorContains(t, err, "job already completed") }) + t.Run("WorkspaceBuild", func(t *testing.T) { + t.Parallel() + srv := setup(t) + build, err := srv.Database.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ + ID: uuid.New(), + }) + require.NoError(t, err) + input, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{ + WorkspaceBuildID: build.ID, + }) + require.NoError(t, err) + job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: uuid.New(), + Provisioner: database.ProvisionerTypeEcho, + Input: input, + }) + require.NoError(t, err) + _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + WorkerID: uuid.NullUUID{ + UUID: srv.ID, + Valid: true, + }, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + }) + require.NoError(t, err) + + publishedWorkspace := make(chan struct{}) + closeWorkspaceSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), func(_ context.Context, _ []byte) { + close(publishedWorkspace) + }) + require.NoError(t, err) + defer closeWorkspaceSubscribe() + publishedLogs := make(chan struct{}) + closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionerdserver.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) { + close(publishedLogs) + }) + require.NoError(t, err) + defer closeLogsSubscribe() + + _, err = srv.FailJob(ctx, &proto.FailedJob{ + JobId: 
job.ID.String(), + Type: &proto.FailedJob_WorkspaceBuild_{ + WorkspaceBuild: &proto.FailedJob_WorkspaceBuild{ + State: []byte("some state"), + }, + }, + }) + require.NoError(t, err) + <-publishedWorkspace + <-publishedLogs + build, err = srv.Database.GetWorkspaceBuildByID(ctx, build.ID) + require.NoError(t, err) + require.Equal(t, "some state", string(build.ProvisionerState)) + }) } func setup(t *testing.T) *provisionerdserver.Server { From a57a934034616a7516ece1eca0aa09df2fae395c Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Tue, 8 Nov 2022 00:22:50 +0000 Subject: [PATCH 4/5] Add test for workspace resources --- .../provisionerdserver/provisionerdserver.go | 13 +- .../provisionerdserver_test.go | 263 +++++++++++++++++- 2 files changed, 269 insertions(+), 7 deletions(-) diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index b7dfe0360fe39..96415069e13a3 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -489,7 +489,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete return nil, xerrors.Errorf("get job by id: %w", err) } if job.WorkerID.UUID.String() != server.ID.String() { - return nil, xerrors.Errorf("you don't have permission to update this job") + return nil, xerrors.Errorf("you don't own this job") } telemetrySnapshot := &telemetry.Snapshot{} @@ -509,7 +509,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete slog.F("resource_type", resource.Type), slog.F("transition", transition)) - err = insertWorkspaceResource(ctx, server.Database, jobID, transition, resource, telemetrySnapshot) + err = InsertWorkspaceResource(ctx, server.Database, jobID, transition, resource, telemetrySnapshot) if err != nil { return nil, xerrors.Errorf("insert resource: %w", err) } @@ -578,7 +578,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete } // This 
could be a bulk insert to improve performance. for _, protoResource := range jobType.WorkspaceBuild.Resources { - err = insertWorkspaceResource(ctx, db, job.ID, workspaceBuild.Transition, protoResource, telemetrySnapshot) + err = InsertWorkspaceResource(ctx, db, job.ID, workspaceBuild.Transition, protoResource, telemetrySnapshot) if err != nil { return xerrors.Errorf("insert provisioner job: %w", err) } @@ -614,7 +614,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete slog.F("resource_name", resource.Name), slog.F("resource_type", resource.Type)) - err = insertWorkspaceResource(ctx, server.Database, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot) + err = InsertWorkspaceResource(ctx, server.Database, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot) if err != nil { return nil, xerrors.Errorf("insert resource: %w", err) } @@ -637,6 +637,9 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete } default: + if completed.Type == nil { + return nil, xerrors.Errorf("type payload must be provided") + } return nil, xerrors.Errorf("unknown job type %q; ensure coderd and provisionerd versions match", reflect.TypeOf(completed.Type).String()) } @@ -655,7 +658,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete return &proto.Empty{}, nil } -func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.UUID, transition database.WorkspaceTransition, protoResource *sdkproto.Resource, snapshot *telemetry.Snapshot) error { +func InsertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.UUID, transition database.WorkspaceTransition, protoResource *sdkproto.Resource, snapshot *telemetry.Snapshot) error { resource, err := db.InsertWorkspaceResource(ctx, database.InsertWorkspaceResourceParams{ ID: uuid.New(), CreatedAt: database.Now(), diff --git a/coderd/provisionerdserver/provisionerdserver_test.go 
b/coderd/provisionerdserver/provisionerdserver_test.go index 9bd4d98959475..dff118a8150b8 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -12,7 +12,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/database/dbtestutil" + "github.com/coder/coder/coderd/database/databasefake" "github.com/coder/coder/coderd/provisionerdserver" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/codersdk" @@ -498,9 +498,268 @@ func TestFailJob(t *testing.T) { }) } +func TestCompleteJob(t *testing.T) { + t.Parallel() + ctx := context.Background() + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + srv := setup(t) + _, err := srv.CompleteJob(ctx, &proto.CompletedJob{ + JobId: "hello", + }) + require.ErrorContains(t, err, "invalid UUID") + + _, err = srv.CompleteJob(ctx, &proto.CompletedJob{ + JobId: uuid.NewString(), + }) + require.ErrorContains(t, err, "no rows in result set") + }) + // This test prevents runners from updating jobs they don't own! 
+ t.Run("NotOwner", func(t *testing.T) { + t.Parallel() + srv := setup(t) + job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: uuid.New(), + Provisioner: database.ProvisionerTypeEcho, + }) + require.NoError(t, err) + _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + WorkerID: uuid.NullUUID{ + UUID: uuid.New(), + Valid: true, + }, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + }) + require.NoError(t, err) + _, err = srv.CompleteJob(ctx, &proto.CompletedJob{ + JobId: job.ID.String(), + }) + require.ErrorContains(t, err, "you don't own this job") + }) + t.Run("TemplateImport", func(t *testing.T) { + t.Parallel() + srv := setup(t) + job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: uuid.New(), + Provisioner: database.ProvisionerTypeEcho, + }) + require.NoError(t, err) + _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + WorkerID: uuid.NullUUID{ + UUID: srv.ID, + Valid: true, + }, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + }) + require.NoError(t, err) + _, err = srv.CompleteJob(ctx, &proto.CompletedJob{ + JobId: job.ID.String(), + Type: &proto.CompletedJob_TemplateImport_{ + TemplateImport: &proto.CompletedJob_TemplateImport{ + StartResources: []*sdkproto.Resource{{ + Name: "hello", + Type: "aws_instance", + }}, + StopResources: []*sdkproto.Resource{}, + }, + }, + }) + require.NoError(t, err) + }) + t.Run("WorkspaceBuild", func(t *testing.T) { + t.Parallel() + srv := setup(t) + workspace, err := srv.Database.InsertWorkspace(ctx, database.InsertWorkspaceParams{ + ID: uuid.New(), + }) + require.NoError(t, err) + build, err := srv.Database.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ + ID: uuid.New(), + WorkspaceID: workspace.ID, + Transition: database.WorkspaceTransitionDelete, + }) + require.NoError(t, err) + input, err := 
json.Marshal(provisionerdserver.WorkspaceProvisionJob{ + WorkspaceBuildID: build.ID, + }) + require.NoError(t, err) + job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: uuid.New(), + Provisioner: database.ProvisionerTypeEcho, + Input: input, + }) + require.NoError(t, err) + _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + WorkerID: uuid.NullUUID{ + UUID: srv.ID, + Valid: true, + }, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + }) + require.NoError(t, err) + + publishedWorkspace := make(chan struct{}) + closeWorkspaceSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), func(_ context.Context, _ []byte) { + close(publishedWorkspace) + }) + require.NoError(t, err) + defer closeWorkspaceSubscribe() + publishedLogs := make(chan struct{}) + closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionerdserver.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) { + close(publishedLogs) + }) + require.NoError(t, err) + defer closeLogsSubscribe() + + _, err = srv.CompleteJob(ctx, &proto.CompletedJob{ + JobId: job.ID.String(), + Type: &proto.CompletedJob_WorkspaceBuild_{ + WorkspaceBuild: &proto.CompletedJob_WorkspaceBuild{ + State: []byte{}, + Resources: []*sdkproto.Resource{{ + Name: "example", + Type: "aws_instance", + }}, + }, + }, + }) + require.NoError(t, err) + + <-publishedWorkspace + <-publishedLogs + + workspace, err = srv.Database.GetWorkspaceByID(ctx, workspace.ID) + require.NoError(t, err) + require.True(t, workspace.Deleted) + }) + + t.Run("TemplateDryRun", func(t *testing.T) { + t.Parallel() + srv := setup(t) + job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: uuid.New(), + Provisioner: database.ProvisionerTypeEcho, + }) + require.NoError(t, err) + _, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + WorkerID: 
uuid.NullUUID{ + UUID: srv.ID, + Valid: true, + }, + Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, + }) + require.NoError(t, err) + + _, err = srv.CompleteJob(ctx, &proto.CompletedJob{ + JobId: job.ID.String(), + Type: &proto.CompletedJob_TemplateDryRun_{ + TemplateDryRun: &proto.CompletedJob_TemplateDryRun{ + Resources: []*sdkproto.Resource{{ + Name: "something", + Type: "aws_instance", + }}, + }, + }, + }) + require.NoError(t, err) + }) +} + +func TestInsertWorkspaceResource(t *testing.T) { + t.Parallel() + ctx := context.Background() + insert := func(db database.Store, jobID uuid.UUID, resource *sdkproto.Resource) error { + return provisionerdserver.InsertWorkspaceResource(ctx, db, jobID, database.WorkspaceTransitionStart, resource, &telemetry.Snapshot{}) + } + t.Run("NoAgents", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + job := uuid.New() + err := insert(db, job, &sdkproto.Resource{ + Name: "something", + Type: "aws_instance", + }) + require.NoError(t, err) + resources, err := db.GetWorkspaceResourcesByJobID(ctx, job) + require.NoError(t, err) + require.Len(t, resources, 1) + }) + t.Run("InvalidAgentToken", func(t *testing.T) { + t.Parallel() + err := insert(databasefake.New(), uuid.New(), &sdkproto.Resource{ + Name: "something", + Type: "aws_instance", + Agents: []*sdkproto.Agent{{ + Auth: &sdkproto.Agent_Token{ + Token: "bananas", + }, + }}, + }) + require.ErrorContains(t, err, "invalid UUID length") + }) + t.Run("DuplicateApps", func(t *testing.T) { + t.Parallel() + err := insert(databasefake.New(), uuid.New(), &sdkproto.Resource{ + Name: "something", + Type: "aws_instance", + Agents: []*sdkproto.Agent{{ + Apps: []*sdkproto.App{{ + Slug: "a", + }, { + Slug: "a", + }}, + }}, + }) + require.ErrorContains(t, err, "duplicate app slug") + }) + t.Run("Success", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + job := uuid.New() + err := insert(db, job, &sdkproto.Resource{ + Name: "something", + Type: 
"aws_instance", + Agents: []*sdkproto.Agent{{ + Name: "dev", + Env: map[string]string{ + "something": "test", + }, + StartupScript: "value", + OperatingSystem: "linux", + Architecture: "amd64", + Auth: &sdkproto.Agent_Token{ + Token: uuid.NewString(), + }, + Apps: []*sdkproto.App{{ + Slug: "a", + }}, + }}, + }) + require.NoError(t, err) + resources, err := db.GetWorkspaceResourcesByJobID(ctx, job) + require.NoError(t, err) + require.Len(t, resources, 1) + agents, err := db.GetWorkspaceAgentsByResourceIDs(ctx, []uuid.UUID{resources[0].ID}) + require.NoError(t, err) + require.Len(t, agents, 1) + agent := agents[0] + require.Equal(t, "amd64", agent.Architecture) + require.Equal(t, "linux", agent.OperatingSystem) + require.Equal(t, "value", agent.StartupScript.String) + want, err := json.Marshal(map[string]string{ + "something": "test", + }) + require.NoError(t, err) + got, err := agent.EnvironmentVariables.RawMessage.MarshalJSON() + require.NoError(t, err) + require.Equal(t, want, got) + }) +} + func setup(t *testing.T) *provisionerdserver.Server { t.Helper() - db, pubsub := dbtestutil.NewDB(t) + db := databasefake.New() + pubsub := database.NewPubsubInMemory() return &provisionerdserver.Server{ ID: uuid.New(), From aeede0ec2428da7cb7f2d8695cc2283d36cfd186 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Tue, 8 Nov 2022 00:59:19 +0000 Subject: [PATCH 5/5] Disable flakey test --- cli/loadtest_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cli/loadtest_test.go b/cli/loadtest_test.go index 544031e92f3e2..eda0084372d9e 100644 --- a/cli/loadtest_test.go +++ b/cli/loadtest_test.go @@ -138,6 +138,7 @@ func TestLoadTest(t *testing.T) { t.Run("OutputFormats", func(t *testing.T) { t.Parallel() + t.Skip("This test is flakey. See: https://github.com/coder/coder/actions/runs/3415360091/jobs/5684401383") type outputFlag struct { format string