diff --git a/coderd/coderd.go b/coderd/coderd.go index b6c01e3e30745..c02a802ff5e10 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -16,18 +16,25 @@ import ( type Options struct { Logger slog.Logger Database database.Store + Pubsub database.Pubsub } // New constructs the Coder API into an HTTP handler. func New(options *Options) http.Handler { projects := &projects{ Database: options.Database, + Pubsub: options.Pubsub, + } + provisioners := &provisioners{ + Database: options.Database, + Pubsub: options.Pubsub, } users := &users{ Database: options.Database, } workspaces := &workspaces{ Database: options.Database, + Pubsub: options.Pubsub, } r := chi.NewRouter() @@ -39,6 +46,8 @@ func New(options *Options) http.Handler { }) r.Post("/login", users.loginWithPassword) r.Post("/logout", users.logout) + r.Get("/provisionerd", provisioners.listen) + // Used for setup. r.Post("/user", users.createInitialUser) r.Route("/users", func(r chi.Router) { @@ -67,6 +76,10 @@ func New(options *Options) http.Handler { r.Route("/history", func(r chi.Router) { r.Get("/", projects.allProjectHistory) r.Post("/", projects.createProjectHistory) + r.Route("/{projecthistory}", func(r chi.Router) { + r.Use(httpmw.ExtractProjectHistoryParam(options.Database)) + r.Get("/logs", projects.projectHistoryLogs) + }) }) r.Get("/workspaces", workspaces.allWorkspacesForProject) }) @@ -89,10 +102,18 @@ func New(options *Options) http.Handler { r.Post("/", workspaces.createWorkspaceHistory) r.Get("/", workspaces.listAllWorkspaceHistory) r.Get("/latest", workspaces.latestWorkspaceHistory) + r.Route("/{workspacehistory}", func(r chi.Router) { + r.Use(httpmw.ExtractWorkspaceHistoryParam(options.Database)) + r.Get("/logs", workspaces.workspaceHistoryLogs) + }) }) }) }) }) + + r.Route("/provisioners", func(r chi.Router) { + r.Get("/daemons", provisioners.listDaemons) + }) }) r.NotFound(site.Handler().ServeHTTP) return r diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 1ecf069bce864..3ed4d4366a2b8 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -3,13 +3,16 @@ package coderdtest import ( "context" "database/sql" + "io" "net/http/httptest" "net/url" "os" "testing" + "time" "github.com/stretchr/testify/require" + "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/coderd" "github.com/coder/coder/codersdk" @@ -17,6 +20,10 @@ import ( "github.com/coder/coder/database" "github.com/coder/coder/database/databasefake" "github.com/coder/coder/database/postgres" + "github.com/coder/coder/provisioner/terraform" + "github.com/coder/coder/provisionerd" + "github.com/coder/coder/provisionersdk" + "github.com/coder/coder/provisionersdk/proto" ) // Server represents a test instance of coderd. @@ -57,11 +64,44 @@ func (s *Server) RandomInitialUser(t *testing.T) coderd.CreateInitialUserRequest return req } +// AddProvisionerd launches a new provisionerd instance! 
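AddProvisionerd, defined next, wires an in-process Terraform provisioner to the daemon over a transport pipe, serves it on a goroutine, and tears everything down with t.Cleanup. A rough, stdlib-only sketch of that pipe-and-cleanup lifecycle is below; the names are stand-ins, since the real helper uses provisionersdk.TransportPipe, terraform.Serve, and provisionerd.New.

```go
package example

import (
	"context"
	"net"
	"testing"
)

// serveFakeProvisioner stands in for terraform.Serve in this sketch: it holds
// the connection open until the test context is cancelled.
func serveFakeProvisioner(ctx context.Context, conn net.Conn) error {
	defer conn.Close()
	<-ctx.Done()
	return ctx.Err()
}

// startFakeProvisioner mirrors the lifecycle AddProvisionerd uses: create a
// pipe, serve one end on a goroutine, and register cleanup that closes both
// ends and cancels the context so the goroutine exits with the test.
func startFakeProvisioner(t *testing.T) net.Conn {
	client, server := net.Pipe()
	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(func() {
		_ = client.Close()
		_ = server.Close()
		cancel()
	})
	go func() {
		// Errors caused by cancellation or closed pipes are expected here.
		_ = serveFakeProvisioner(ctx, server)
	}()
	return client
}
```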
+func (s *Server) AddProvisionerd(t *testing.T) io.Closer { + tfClient, tfServer := provisionersdk.TransportPipe() + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(func() { + _ = tfClient.Close() + _ = tfServer.Close() + cancelFunc() + }) + go func() { + err := terraform.Serve(ctx, &terraform.ServeOptions{ + ServeOptions: &provisionersdk.ServeOptions{ + Listener: tfServer, + }, + }) + require.NoError(t, err) + }() + + closer := provisionerd.New(s.Client.ProvisionerDaemonClient, &provisionerd.Options{ + Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelInfo), + PollInterval: 50 * time.Millisecond, + Provisioners: provisionerd.Provisioners{ + string(database.ProvisionerTypeTerraform): proto.NewDRPCProvisionerClient(provisionersdk.Conn(tfClient)), + }, + WorkDirectory: t.TempDir(), + }) + t.Cleanup(func() { + _ = closer.Close() + }) + return closer +} + // New constructs a new coderd test instance. This returned Server // should contain no side-effects. func New(t *testing.T) Server { // This can be hotswapped for a live database instance. db := databasefake.New() + pubsub := database.NewPubsubInMemory() if os.Getenv("DB") != "" { connectionURL, close, err := postgres.Open() require.NoError(t, err) @@ -74,11 +114,15 @@ func New(t *testing.T) Server { err = database.Migrate(sqlDB) require.NoError(t, err) db = database.New(sqlDB) + + pubsub, err = database.NewPubsub(context.Background(), sqlDB, connectionURL) + require.NoError(t, err) } handler := coderd.New(&coderd.Options{ Logger: slogtest.Make(t, nil), Database: db, + Pubsub: pubsub, }) srv := httptest.NewServer(handler) serverURL, err := url.Parse(srv.URL) diff --git a/coderd/coderdtest/coderdtest_test.go b/coderd/coderdtest/coderdtest_test.go index e36d1c1408cd1..b7312f96864fc 100644 --- a/coderd/coderdtest/coderdtest_test.go +++ b/coderd/coderdtest/coderdtest_test.go @@ -16,4 +16,5 @@ func TestNew(t *testing.T) { t.Parallel() server := coderdtest.New(t) _ = server.RandomInitialUser(t) + _ = server.AddProvisionerd(t) } diff --git a/coderd/projects.go b/coderd/projects.go index 5ef2ea5067b6a..157da1ccd9651 100644 --- a/coderd/projects.go +++ b/coderd/projects.go @@ -3,7 +3,9 @@ package coderd import ( "archive/tar" "bytes" + "context" "database/sql" + "encoding/json" "errors" "fmt" "net/http" @@ -34,6 +36,14 @@ type ProjectHistory struct { StorageMethod database.ProjectStorageMethod `json:"storage_method"` } +type ProjectHistoryLog struct { + ID uuid.UUID + CreatedAt time.Time `json:"created_at"` + Source database.LogSource `json:"log_source"` + Level database.LogLevel `json:"log_level"` + Output string `json:"output"` +} + // CreateProjectRequest enables callers to create a new Project. type CreateProjectRequest struct { Name string `json:"name" validate:"username,required"` @@ -48,6 +58,7 @@ type CreateProjectVersionRequest struct { type projects struct { Database database.Store + Pubsub database.Pubsub } // Lists all projects the authenticated user has access to. @@ -222,6 +233,115 @@ func (p *projects) createProjectHistory(rw http.ResponseWriter, r *http.Request) render.JSON(rw, r, convertProjectHistory(history)) } +func (p *projects) projectHistoryLogs(rw http.ResponseWriter, r *http.Request) { + projectHistory := httpmw.ProjectHistoryParam(r) + follow := r.URL.Query().Has("follow") + + if !follow { + // If we're not attempting to follow logs, + // we can exit immediately! 
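The handler takes two paths from here: a plain JSON response when ?follow is absent, and a subscribe-then-backfill NDJSON stream when it is present. The ordering in the streaming path is the important part: record a cutoff, subscribe, then fetch only rows created before the cutoff, so the backfill and the live stream never overlap. A stripped-down, in-memory sketch of that ordering follows; a slice stands in for the logs table and channels for the pubsub, and every name is illustrative.

```go
package example

import (
	"sync"
	"time"
)

type logEntry struct {
	CreatedAt time.Time
	Output    string
}

type logStore struct {
	mu   sync.Mutex
	rows []logEntry
	subs []chan logEntry
}

// subscribe registers a buffered channel that receives every row appended
// from this point on.
func (s *logStore) subscribe() <-chan logEntry {
	ch := make(chan logEntry, 128)
	s.mu.Lock()
	s.subs = append(s.subs, ch)
	s.mu.Unlock()
	return ch
}

// append stores a row and fans it out to subscribers, mirroring the
// INSERT + Pubsub.Publish pair on the provisioner side.
func (s *logStore) append(row logEntry) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.rows = append(s.rows, row)
	for _, ch := range s.subs {
		select {
		case ch <- row:
		default: // drop if the subscriber is too slow (a simplification)
		}
	}
}

// before returns rows created strictly before t.
func (s *logStore) before(t time.Time) []logEntry {
	s.mu.Lock()
	defer s.mu.Unlock()
	out := []logEntry{}
	for _, row := range s.rows {
		if row.CreatedAt.Before(t) {
			out = append(out, row)
		}
	}
	return out
}

// follow reproduces the handler's ordering: record a cutoff, subscribe, then
// backfill only rows older than the cutoff. Rows landing between the cutoff
// and the subscription can be missed (the handler notes the same caveat), but
// nothing is ever delivered twice.
func follow(s *logStore) ([]logEntry, <-chan logEntry) {
	cutoff := time.Now()
	updates := s.subscribe()
	return s.before(cutoff), updates
}
```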
+ logs, err := p.Database.GetProjectHistoryLogsByIDBefore(r.Context(), database.GetProjectHistoryLogsByIDBeforeParams{ + ProjectHistoryID: projectHistory.ID, + CreatedAt: time.Now(), + }) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get project history logs: %s", err), + }) + return + } + render.Status(r, http.StatusOK) + render.JSON(rw, r, logs) + return + } + + // We only want to fetch messages before subscribe, so that + // there aren't any duplicates. + timeBeforeSubscribe := database.Now() + // Start subscribing immediately, otherwise we could miss messages + // that occur during the database read. + newLogNotify := make(chan ProjectHistoryLog, 128) + cancelNewLogNotify, err := p.Pubsub.Subscribe(projectHistoryLogsChannel(projectHistory.ID), func(ctx context.Context, message []byte) { + var logs []database.ProjectHistoryLog + err := json.Unmarshal(message, &logs) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("parse logs from publish: %s", err), + }) + return + } + for _, log := range logs { + // If many logs are sent during our database query, this channel + // could overflow. The Go scheduler would decide the order to send + // logs in at that point, which is an unfortunate (but not fatal) + // flaw of this approach. + // + // This is an extremely unlikely outcome given reasonable database + // query times. + newLogNotify <- convertProjectHistoryLog(log) + } + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("listen for new logs: %s", err), + }) + return + } + defer cancelNewLogNotify() + + // In-between here logs could be missed! + projectHistoryLogs, err := p.Database.GetProjectHistoryLogsByIDBefore(r.Context(), database.GetProjectHistoryLogsByIDBeforeParams{ + ProjectHistoryID: projectHistory.ID, + CreatedAt: timeBeforeSubscribe, + }) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get project history logs: %s", err), + }) + return + } + + // "follow" uses the ndjson format to stream data. + // See: https://canjs.com/doc/can-ndjson-stream.html + rw.Header().Set("Content-Type", "application/stream+json") + rw.WriteHeader(http.StatusOK) + rw.(http.Flusher).Flush() + + // The Go stdlib JSON encoder appends a newline character after message write. 
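That trailing newline is exactly what the ndjson framing relies on: each Encode call emits one complete JSON value followed by '\n'. A tiny standalone check of the behavior:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

func main() {
	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	// Each Encode call writes one JSON value followed by a single '\n',
	// so consecutive calls produce newline-delimited JSON.
	_ = enc.Encode(map[string]string{"output": "first"})
	_ = enc.Encode(map[string]string{"output": "second"})
	fmt.Printf("%q\n", buf.String())
	// Prints: "{\"output\":\"first\"}\n{\"output\":\"second\"}\n"
}
```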
+ encoder := json.NewEncoder(rw) + for _, projectHistoryLog := range projectHistoryLogs { + // JSON separated by a newline + err = encoder.Encode(convertProjectHistoryLog(projectHistoryLog)) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("marshal: %s", err), + }) + return + } + } + + for { + select { + case <-r.Context().Done(): + return + case log := <-newLogNotify: + err = encoder.Encode(log) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("marshal follow: %s", err), + }) + return + } + } + } +} + func convertProjectHistory(history database.ProjectHistory) ProjectHistory { return ProjectHistory{ ID: history.ID, @@ -231,3 +351,17 @@ func convertProjectHistory(history database.ProjectHistory) ProjectHistory { Name: history.Name, } } + +func convertProjectHistoryLog(log database.ProjectHistoryLog) ProjectHistoryLog { + return ProjectHistoryLog{ + ID: log.ID, + CreatedAt: log.CreatedAt, + Source: log.Source, + Level: log.Level, + Output: log.Output, + } +} + +func projectHistoryLogsChannel(projectHistoryID uuid.UUID) string { + return fmt.Sprintf("project-history-logs:%s", projectHistoryID) +} diff --git a/coderd/provisioners.go b/coderd/provisioners.go new file mode 100644 index 0000000000000..3da0efb64758f --- /dev/null +++ b/coderd/provisioners.go @@ -0,0 +1,603 @@ +package coderd + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "net/http" + "reflect" + "time" + + "golang.org/x/xerrors" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + + "github.com/go-chi/render" + "github.com/google/uuid" + "github.com/hashicorp/yamux" + "github.com/moby/moby/pkg/namesgenerator" + + "github.com/coder/coder/coderd/projectparameter" + "github.com/coder/coder/database" + "github.com/coder/coder/httpapi" + "github.com/coder/coder/provisionerd/proto" + sdkproto "github.com/coder/coder/provisionersdk/proto" + + "nhooyr.io/websocket" +) + +type ProvisionerDaemon database.ProvisionerDaemon + +type provisioners struct { + Database database.Store + Pubsub database.Pubsub +} + +func (p *provisioners) listDaemons(rw http.ResponseWriter, r *http.Request) { + daemons, err := p.Database.GetProvisionerDaemons(r.Context()) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get provisioner daemons: %s", err), + }) + return + } + + render.Status(r, http.StatusOK) + render.JSON(rw, r, daemons) +} + +func (p *provisioners) listen(rw http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(rw, r, nil) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("accept websocket: %s", err), + }) + return + } + + daemon, err := p.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + Name: namesgenerator.GetRandomName(1), + Provisioners: []database.ProvisionerType{database.ProvisionerTypeCdrBasic, database.ProvisionerTypeTerraform}, + }) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, fmt.Sprintf("insert provisioner daemon:% s", err)) + return + } + + session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), nil) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, fmt.Sprintf("multiplex server: %s", err)) + return + } + mux := drpcmux.New() + err = 
proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdServer{ + ID: daemon.ID, + Database: p.Database, + Pubsub: p.Pubsub, + }) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, fmt.Sprintf("drpc register provisioner daemon: %s", err)) + return + } + server := drpcserver.New(mux) + err = server.Serve(r.Context(), session) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, fmt.Sprintf("serve: %s", err)) + } +} + +// The input for a "workspace_provision" job. +type workspaceProvisionJob struct { + WorkspaceHistoryID uuid.UUID `json:"workspace_history_id"` +} + +// The input for a "project_import" job. +type projectImportJob struct { + ProjectHistoryID uuid.UUID `json:"project_history_id"` +} + +// An implementation of the provisionerd protobuf server definition. +type provisionerdServer struct { + ID uuid.UUID + Database database.Store + Pubsub database.Pubsub +} + +func (s *provisionerdServer) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + // This locks the job. No other provisioners can acquire this job. + job, err := s.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + StartedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + WorkerID: uuid.NullUUID{ + UUID: s.ID, + Valid: true, + }, + Types: []database.ProvisionerType{database.ProvisionerTypeTerraform}, + }) + if errors.Is(err, sql.ErrNoRows) { + // If no jobs are available, an empty struct is sent back. + return &proto.AcquiredJob{}, nil + } + if err != nil { + return nil, xerrors.Errorf("acquire job: %w", err) + } + failJob := func(errorMessage string) error { + err = s.Database.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{ + ID: job.ID, + CompletedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + Error: sql.NullString{ + String: errorMessage, + Valid: true, + }, + }) + if err != nil { + return xerrors.Errorf("update provisioner job: %w", err) + } + return xerrors.Errorf("request job was invalidated: %s", errorMessage) + } + + project, err := s.Database.GetProjectByID(ctx, job.ProjectID) + if err != nil { + return nil, failJob(fmt.Sprintf("get project: %s", err)) + } + + organization, err := s.Database.GetOrganizationByID(ctx, project.OrganizationID) + if err != nil { + return nil, failJob(fmt.Sprintf("get organization: %s", err)) + } + + user, err := s.Database.GetUserByID(ctx, job.InitiatorID) + if err != nil { + return nil, failJob(fmt.Sprintf("get user: %s", err)) + } + + acquiredJob := &proto.AcquiredJob{ + JobId: job.ID.String(), + CreatedAt: job.CreatedAt.UnixMilli(), + Provisioner: string(job.Provisioner), + OrganizationName: organization.Name, + ProjectName: project.Name, + UserName: user.Username, + } + var projectHistory database.ProjectHistory + switch job.Type { + case database.ProvisionerJobTypeWorkspaceProvision: + var input workspaceProvisionJob + err = json.Unmarshal(job.Input, &input) + if err != nil { + return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err)) + } + workspaceHistory, err := s.Database.GetWorkspaceHistoryByID(ctx, input.WorkspaceHistoryID) + if err != nil { + return nil, failJob(fmt.Sprintf("get workspace history: %s", err)) + } + + workspace, err := s.Database.GetWorkspaceByID(ctx, workspaceHistory.WorkspaceID) + if err != nil { + return nil, failJob(fmt.Sprintf("get workspace: %s", err)) + } + + projectHistory, err = s.Database.GetProjectHistoryByID(ctx, workspaceHistory.ProjectHistoryID) + if err != nil { + return nil, 
failJob(fmt.Sprintf("get project history: %s", err)) + } + + parameters, err := projectparameter.Compute(ctx, s.Database, projectparameter.Scope{ + OrganizationID: organization.ID, + ProjectID: project.ID, + ProjectHistoryID: projectHistory.ID, + UserID: user.ID, + WorkspaceID: workspace.ID, + WorkspaceHistoryID: workspaceHistory.ID, + }) + if err != nil { + return nil, failJob(fmt.Sprintf("compute parameters: %s", err)) + } + protoParameters := make([]*sdkproto.ParameterValue, 0, len(parameters)) + for _, parameter := range parameters { + protoParameters = append(protoParameters, parameter.Proto) + } + + provisionerState := []byte{} + if workspaceHistory.BeforeID.Valid { + beforeHistory, err := s.Database.GetWorkspaceHistoryByID(ctx, workspaceHistory.BeforeID.UUID) + if err != nil { + return nil, failJob(fmt.Sprintf("get workspace history: %s", err)) + } + provisionerState = beforeHistory.ProvisionerState + } + + acquiredJob.Type = &proto.AcquiredJob_WorkspaceProvision_{ + WorkspaceProvision: &proto.AcquiredJob_WorkspaceProvision{ + WorkspaceHistoryId: workspaceHistory.ID.String(), + WorkspaceName: workspace.Name, + State: provisionerState, + ParameterValues: protoParameters, + }, + } + case database.ProvisionerJobTypeProjectImport: + var input projectImportJob + err = json.Unmarshal(job.Input, &input) + if err != nil { + return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err)) + } + projectHistory, err = s.Database.GetProjectHistoryByID(ctx, input.ProjectHistoryID) + if err != nil { + return nil, failJob(fmt.Sprintf("get project history: %s", err)) + } + } + switch projectHistory.StorageMethod { + case database.ProjectStorageMethodInlineArchive: + acquiredJob.ProjectSourceArchive = projectHistory.StorageSource + default: + return nil, failJob(fmt.Sprintf("unsupported storage source: %q", projectHistory.StorageMethod)) + } + + return acquiredJob, err +} + +func (s *provisionerdServer) UpdateJob(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error { + for { + update, err := stream.Recv() + if err != nil { + return err + } + parsedID, err := uuid.Parse(update.JobId) + if err != nil { + return xerrors.Errorf("parse job id: %w", err) + } + job, err := s.Database.GetProvisionerJobByID(stream.Context(), parsedID) + if err != nil { + return xerrors.Errorf("get job: %w", err) + } + if !job.WorkerID.Valid { + return errors.New("job isn't running yet") + } + if job.WorkerID.UUID.String() != s.ID.String() { + return errors.New("you don't own this job") + } + + err = s.Database.UpdateProvisionerJobByID(stream.Context(), database.UpdateProvisionerJobByIDParams{ + ID: parsedID, + UpdatedAt: database.Now(), + }) + if err != nil { + return xerrors.Errorf("update job: %w", err) + } + switch job.Type { + case database.ProvisionerJobTypeProjectImport: + if len(update.ProjectImportLogs) == 0 { + continue + } + var input projectImportJob + err = json.Unmarshal(job.Input, &input) + if err != nil { + return xerrors.Errorf("unmarshal job input %q: %s", job.Input, err) + } + insertParams := database.InsertProjectHistoryLogsParams{ + ProjectHistoryID: input.ProjectHistoryID, + } + for _, log := range update.ProjectImportLogs { + logLevel, err := convertLogLevel(log.Level) + if err != nil { + return xerrors.Errorf("convert log level: %w", err) + } + logSource, err := convertLogSource(log.Source) + if err != nil { + return xerrors.Errorf("convert log source: %w", err) + } + insertParams.ID = append(insertParams.ID, uuid.New()) + insertParams.CreatedAt = append(insertParams.CreatedAt, 
time.UnixMilli(log.CreatedAt)) + insertParams.Level = append(insertParams.Level, logLevel) + insertParams.Source = append(insertParams.Source, logSource) + insertParams.Output = append(insertParams.Output, log.Output) + } + logs, err := s.Database.InsertProjectHistoryLogs(stream.Context(), insertParams) + if err != nil { + return xerrors.Errorf("insert project logs: %w", err) + } + data, err := json.Marshal(logs) + if err != nil { + return xerrors.Errorf("marshal project log: %w", err) + } + err = s.Pubsub.Publish(projectHistoryLogsChannel(input.ProjectHistoryID), data) + if err != nil { + return xerrors.Errorf("publish history log: %w", err) + } + case database.ProvisionerJobTypeWorkspaceProvision: + if len(update.WorkspaceProvisionLogs) == 0 { + continue + } + var input workspaceProvisionJob + err = json.Unmarshal(job.Input, &input) + if err != nil { + return xerrors.Errorf("unmarshal job input %q: %s", job.Input, err) + } + insertParams := database.InsertWorkspaceHistoryLogsParams{ + WorkspaceHistoryID: input.WorkspaceHistoryID, + } + for _, log := range update.WorkspaceProvisionLogs { + logLevel, err := convertLogLevel(log.Level) + if err != nil { + return xerrors.Errorf("convert log level: %w", err) + } + logSource, err := convertLogSource(log.Source) + if err != nil { + return xerrors.Errorf("convert log source: %w", err) + } + insertParams.ID = append(insertParams.ID, uuid.New()) + insertParams.CreatedAt = append(insertParams.CreatedAt, time.UnixMilli(log.CreatedAt)) + insertParams.Level = append(insertParams.Level, logLevel) + insertParams.Source = append(insertParams.Source, logSource) + insertParams.Output = append(insertParams.Output, log.Output) + } + logs, err := s.Database.InsertWorkspaceHistoryLogs(stream.Context(), insertParams) + if err != nil { + return xerrors.Errorf("insert workspace logs: %w", err) + } + data, err := json.Marshal(logs) + if err != nil { + return xerrors.Errorf("marshal project log: %w", err) + } + err = s.Pubsub.Publish(workspaceHistoryLogsChannel(input.WorkspaceHistoryID), data) + if err != nil { + return xerrors.Errorf("publish history log: %w", err) + } + } + } +} + +func (s *provisionerdServer) CancelJob(ctx context.Context, cancelJob *proto.CancelledJob) (*proto.Empty, error) { + jobID, err := uuid.Parse(cancelJob.JobId) + if err != nil { + return nil, xerrors.Errorf("parse job id: %w", err) + } + err = s.Database.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{ + ID: jobID, + CancelledAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + UpdatedAt: database.Now(), + Error: sql.NullString{ + String: cancelJob.Error, + Valid: cancelJob.Error != "", + }, + }) + if err != nil { + return nil, xerrors.Errorf("update provisioner job: %w", err) + } + return &proto.Empty{}, nil +} + +// CompleteJob is triggered by a provision daemon to mark a provisioner job as completed. +func (s *provisionerdServer) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) { + jobID, err := uuid.Parse(completed.JobId) + if err != nil { + return nil, xerrors.Errorf("parse job id: %w", err) + } + job, err := s.Database.GetProvisionerJobByID(ctx, jobID) + if err != nil { + return nil, xerrors.Errorf("get job by id: %w", err) + } + // TODO: Check if the worker ID matches! + // If it doesn't, a provisioner daemon could be impersonating another job! 
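One possible shape for the ownership check that TODO describes, mirroring the guard UpdateJob already applies earlier in this file (the job must have been acquired, and the acquiring worker must be this daemon connection). This is a sketch, not part of the patch:

```go
package example

import (
	"errors"

	"github.com/google/uuid"
)

// verifyJobOwner sketches the missing guard: reject completion requests for
// jobs that were never acquired or that were acquired by a different daemon.
func verifyJobOwner(workerID uuid.NullUUID, serverID uuid.UUID) error {
	if !workerID.Valid {
		return errors.New("job isn't running yet")
	}
	if workerID.UUID != serverID {
		return errors.New("you don't own this job")
	}
	return nil
}
```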
+ + switch jobType := completed.Type.(type) { + case *proto.CompletedJob_ProjectImport_: + var input projectImportJob + err = json.Unmarshal(job.Input, &input) + if err != nil { + return nil, xerrors.Errorf("unmarshal job data: %w", err) + } + + // Validate that all parameters send from the provisioner daemon + // follow the protocol. + projectParameters := make([]database.InsertProjectParameterParams, 0, len(jobType.ProjectImport.ParameterSchemas)) + for _, protoParameter := range jobType.ProjectImport.ParameterSchemas { + validationTypeSystem, err := convertValidationTypeSystem(protoParameter.ValidationTypeSystem) + if err != nil { + return nil, xerrors.Errorf("convert validation type system for %q: %w", protoParameter.Name, err) + } + + projectParameter := database.InsertProjectParameterParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + ProjectHistoryID: input.ProjectHistoryID, + Name: protoParameter.Name, + Description: protoParameter.Description, + RedisplayValue: protoParameter.RedisplayValue, + ValidationError: protoParameter.ValidationError, + ValidationCondition: protoParameter.ValidationCondition, + ValidationValueType: protoParameter.ValidationValueType, + ValidationTypeSystem: validationTypeSystem, + + AllowOverrideDestination: protoParameter.AllowOverrideDestination, + AllowOverrideSource: protoParameter.AllowOverrideSource, + } + + // It's possible a parameter doesn't define a default source! + if protoParameter.DefaultSource != nil { + parameterSourceScheme, err := convertParameterSourceScheme(protoParameter.DefaultSource.Scheme) + if err != nil { + return nil, xerrors.Errorf("convert parameter source scheme: %w", err) + } + projectParameter.DefaultSourceScheme = parameterSourceScheme + projectParameter.DefaultSourceValue = sql.NullString{ + String: protoParameter.DefaultSource.Value, + Valid: protoParameter.DefaultSource.Value != "", + } + } + + // It's possible a parameter doesn't define a default destination! + if protoParameter.DefaultDestination != nil { + parameterDestinationScheme, err := convertParameterDestinationScheme(protoParameter.DefaultDestination.Scheme) + if err != nil { + return nil, xerrors.Errorf("convert parameter destination scheme: %w", err) + } + projectParameter.DefaultDestinationScheme = parameterDestinationScheme + projectParameter.DefaultDestinationValue = sql.NullString{ + String: protoParameter.DefaultDestination.Value, + Valid: protoParameter.DefaultDestination.Value != "", + } + } + + projectParameters = append(projectParameters, projectParameter) + } + + // This must occur in a transaction in case of failure. 
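s.Database.InTx runs the callback inside a single database transaction, so marking the job complete and inserting the project parameters either both land or neither does. A generic sketch of that contract over database/sql (the real store may implement it differently; this only illustrates the begin/rollback/commit guarantee):

```go
package example

import (
	"context"
	"database/sql"
	"fmt"
)

// inTx runs fn inside one transaction: any error leaves the database
// untouched via the deferred Rollback, otherwise Commit makes the whole
// batch visible at once.
func inTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	// Rollback after a successful Commit is a no-op error we can ignore.
	defer func() { _ = tx.Rollback() }()

	if err := fn(tx); err != nil {
		return err
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction: %w", err)
	}
	return nil
}
```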
+ err = s.Database.InTx(func(db database.Store) error { + err = db.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{ + ID: jobID, + UpdatedAt: database.Now(), + CompletedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + }) + if err != nil { + return xerrors.Errorf("update provisioner job: %w", err) + } + for _, projectParameter := range projectParameters { + _, err = db.InsertProjectParameter(ctx, projectParameter) + if err != nil { + return xerrors.Errorf("insert project parameter %q: %w", projectParameter.Name, err) + } + } + return nil + }) + if err != nil { + return nil, xerrors.Errorf("complete job: %w", err) + } + case *proto.CompletedJob_WorkspaceProvision_: + var input workspaceProvisionJob + err = json.Unmarshal(job.Input, &input) + if err != nil { + return nil, xerrors.Errorf("unmarshal job data: %w", err) + } + + workspaceHistory, err := s.Database.GetWorkspaceHistoryByID(ctx, input.WorkspaceHistoryID) + if err != nil { + return nil, xerrors.Errorf("get workspace history: %w", err) + } + + err = s.Database.InTx(func(db database.Store) error { + err = db.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{ + ID: jobID, + UpdatedAt: database.Now(), + CompletedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + }) + if err != nil { + return xerrors.Errorf("update provisioner job: %w", err) + } + err = db.UpdateWorkspaceHistoryByID(ctx, database.UpdateWorkspaceHistoryByIDParams{ + ID: workspaceHistory.ID, + UpdatedAt: database.Now(), + ProvisionerState: jobType.WorkspaceProvision.State, + CompletedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + }) + if err != nil { + return xerrors.Errorf("update workspace history: %w", err) + } + for _, protoResource := range jobType.WorkspaceProvision.Resources { + _, err = db.InsertWorkspaceResource(ctx, database.InsertWorkspaceResourceParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + WorkspaceHistoryID: input.WorkspaceHistoryID, + Type: protoResource.Type, + Name: protoResource.Name, + // TODO: Generate this at the variable validation phase. + // Set the value in `default_source`, and disallow overwrite. 
+ WorkspaceAgentToken: uuid.NewString(), + }) + if err != nil { + return xerrors.Errorf("insert workspace resource %q: %w", protoResource.Name, err) + } + } + return nil + }) + if err != nil { + return nil, xerrors.Errorf("complete job: %w", err) + } + default: + return nil, xerrors.Errorf("unknown job type %q; ensure coderd and provisionerd versions match", + reflect.TypeOf(completed.Type).String()) + } + + return &proto.Empty{}, nil +} + +func convertValidationTypeSystem(typeSystem sdkproto.ParameterSchema_TypeSystem) (database.ParameterTypeSystem, error) { + switch typeSystem { + case sdkproto.ParameterSchema_HCL: + return database.ParameterTypeSystemHCL, nil + default: + return database.ParameterTypeSystem(""), xerrors.Errorf("unknown type system: %d", typeSystem) + } +} + +func convertParameterSourceScheme(sourceScheme sdkproto.ParameterSource_Scheme) (database.ParameterSourceScheme, error) { + switch sourceScheme { + case sdkproto.ParameterSource_DATA: + return database.ParameterSourceSchemeData, nil + default: + return database.ParameterSourceScheme(""), xerrors.Errorf("unknown parameter source scheme: %d", sourceScheme) + } +} + +func convertParameterDestinationScheme(destinationScheme sdkproto.ParameterDestination_Scheme) (database.ParameterDestinationScheme, error) { + switch destinationScheme { + case sdkproto.ParameterDestination_ENVIRONMENT_VARIABLE: + return database.ParameterDestinationSchemeEnvironmentVariable, nil + case sdkproto.ParameterDestination_PROVISIONER_VARIABLE: + return database.ParameterDestinationSchemeProvisionerVariable, nil + default: + return database.ParameterDestinationScheme(""), xerrors.Errorf("unknown parameter destination scheme: %d", destinationScheme) + } +} + +func convertLogLevel(logLevel sdkproto.LogLevel) (database.LogLevel, error) { + switch logLevel { + case sdkproto.LogLevel_TRACE: + return database.LogLevelTrace, nil + case sdkproto.LogLevel_DEBUG: + return database.LogLevelDebug, nil + case sdkproto.LogLevel_INFO: + return database.LogLevelInfo, nil + case sdkproto.LogLevel_WARN: + return database.LogLevelWarn, nil + case sdkproto.LogLevel_ERROR: + return database.LogLevelError, nil + default: + return database.LogLevel(""), xerrors.Errorf("unknown log level: %d", logLevel) + } +} + +func convertLogSource(logSource proto.LogSource) (database.LogSource, error) { + switch logSource { + case proto.LogSource_PROVISIONER_DAEMON: + return database.LogSourceProvisionerDaemon, nil + case proto.LogSource_PROVISIONER: + return database.LogSourceProvisioner, nil + default: + return database.LogSource(""), xerrors.Errorf("unknown log source: %d", logSource) + } +} diff --git a/coderd/provisioners_test.go b/coderd/provisioners_test.go new file mode 100644 index 0000000000000..e65a8ced2a508 --- /dev/null +++ b/coderd/provisioners_test.go @@ -0,0 +1,81 @@ +package coderd_test + +import ( + "archive/tar" + "bytes" + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/database" +) + +func TestProvisionerd(t *testing.T) { + t.Parallel() + t.Run("ListDaemons", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + _ = server.AddProvisionerd(t) + require.Eventually(t, func() bool { + daemons, err := server.Client.ProvisionerDaemons(context.Background()) + require.NoError(t, err) + return len(daemons) > 0 + }, time.Second, 10*time.Millisecond) + }) + + t.Run("RunJob", func(t *testing.T) { + t.Parallel() + server 
:= coderdtest.New(t) + user := server.RandomInitialUser(t) + _ = server.AddProvisionerd(t) + + project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{ + Name: "my-project", + Provisioner: database.ProvisionerTypeTerraform, + }) + require.NoError(t, err) + + var buffer bytes.Buffer + writer := tar.NewWriter(&buffer) + content := `variable "frog" {} + resource "null_resource" "dev" {}` + err = writer.WriteHeader(&tar.Header{ + Name: "main.tf", + Size: int64(len(content)), + }) + require.NoError(t, err) + _, err = writer.Write([]byte(content)) + require.NoError(t, err) + + projectHistory, err := server.Client.CreateProjectHistory(context.Background(), user.Organization, project.Name, coderd.CreateProjectVersionRequest{ + StorageMethod: database.ProjectStorageMethodInlineArchive, + StorageSource: buffer.Bytes(), + }) + require.NoError(t, err) + + workspace, err := server.Client.CreateWorkspace(context.Background(), "", coderd.CreateWorkspaceRequest{ + ProjectID: project.ID, + Name: "wowie", + }) + require.NoError(t, err) + + workspaceHistory, err := server.Client.CreateWorkspaceHistory(context.Background(), "", workspace.Name, coderd.CreateWorkspaceHistoryRequest{ + ProjectHistoryID: projectHistory.ID, + Transition: database.WorkspaceTransitionCreate, + }) + require.NoError(t, err) + + logs, err := server.Client.FollowWorkspaceHistoryLogs(context.Background(), "me", workspace.Name, workspaceHistory.Name) + require.NoError(t, err) + + for { + log := <-logs + fmt.Printf("Got %s %s\n", log.CreatedAt, log.Output) + } + }) +} diff --git a/coderd/workspaces.go b/coderd/workspaces.go index f12633a5611bf..47961ff1dbdb8 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -1,7 +1,9 @@ package coderd import ( + "context" "database/sql" + "encoding/json" "errors" "fmt" "net/http" @@ -9,6 +11,7 @@ import ( "github.com/go-chi/render" "github.com/google/uuid" + "github.com/moby/moby/pkg/namesgenerator" "golang.org/x/xerrors" "github.com/coder/coder/database" @@ -24,6 +27,7 @@ type Workspace database.Workspace // Iterate on before/after to determine a chronological history. type WorkspaceHistory struct { ID uuid.UUID `json:"id"` + Name string `json:"name"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` CompletedAt time.Time `json:"completed_at"` @@ -35,6 +39,14 @@ type WorkspaceHistory struct { Initiator string `json:"initiator"` } +type WorkspaceHistoryLog struct { + ID uuid.UUID + CreatedAt time.Time `json:"created_at"` + Source database.LogSource `json:"log_source"` + Level database.LogLevel `json:"log_level"` + Output string `json:"output"` +} + // CreateWorkspaceRequest provides options for creating a new workspace. type CreateWorkspaceRequest struct { ProjectID uuid.UUID `json:"project_id" validate:"required"` @@ -49,6 +61,7 @@ type CreateWorkspaceHistoryRequest struct { type workspaces struct { Database database.Store + Pubsub database.Pubsub } // Returns all workspaces across all projects and organizations. @@ -270,6 +283,13 @@ func (w *workspaces) createWorkspaceHistory(rw http.ResponseWriter, r *http.Requ }) return } + project, err := w.Database.GetProjectByID(r.Context(), projectHistory.ProjectID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get project: %s", err), + }) + return + } // Store prior history ID if it exists to update it after we create new! 
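Workspace history rows form a chain: the new row records the prior row's ID as its "before" pointer, and the prior row's "after" pointer is updated to reference the new row inside the transaction mentioned below. A toy, in-memory sketch of that linkage (names are illustrative, not the actual schema):

```go
package example

import "github.com/google/uuid"

type historyEntry struct {
	ID       uuid.UUID
	BeforeID uuid.UUID // zero value means "first entry"
	AfterID  uuid.UUID // zero value means "latest entry"
}

// appendHistory creates a new entry pointing back at the current latest entry
// and, if one exists, patches that entry's AfterID to point forward at the new
// one — roughly the pair of writes the transaction below performs.
func appendHistory(chain map[uuid.UUID]*historyEntry, latest uuid.UUID) uuid.UUID {
	newEntry := &historyEntry{ID: uuid.New()}
	if prior, ok := chain[latest]; ok {
		newEntry.BeforeID = prior.ID
		prior.AfterID = newEntry.ID
	}
	chain[newEntry.ID] = newEntry
	return newEntry.ID
}
```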
priorHistoryID := uuid.NullUUID{} @@ -298,17 +318,40 @@ func (w *workspaces) createWorkspaceHistory(rw http.ResponseWriter, r *http.Requ // This must happen in a transaction to ensure history can be inserted, and // the prior history can update it's "after" column to point at the new. err = w.Database.InTx(func(db database.Store) error { + // Generate the ID before-hand so the provisioner job is aware of it! + workspaceHistoryID := uuid.New() + input, err := json.Marshal(workspaceProvisionJob{ + WorkspaceHistoryID: workspaceHistoryID, + }) + if err != nil { + return xerrors.Errorf("marshal provision job: %w", err) + } + + provisionerJob, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + InitiatorID: user.ID, + Provisioner: project.Provisioner, + Type: database.ProvisionerJobTypeWorkspaceProvision, + ProjectID: project.ID, + Input: input, + }) + if err != nil { + return xerrors.Errorf("insert provisioner job: %w", err) + } + workspaceHistory, err = db.InsertWorkspaceHistory(r.Context(), database.InsertWorkspaceHistoryParams{ - ID: uuid.New(), + ID: workspaceHistoryID, CreatedAt: database.Now(), UpdatedAt: database.Now(), WorkspaceID: workspace.ID, + Name: namesgenerator.GetRandomName(1), ProjectHistoryID: projectHistory.ID, BeforeID: priorHistoryID, Initiator: user.ID, Transition: createBuild.Transition, - // This should create a provision job once that gets implemented! - ProvisionJobID: uuid.New(), + ProvisionJobID: provisionerJob.ID, }) if err != nil { return xerrors.Errorf("insert workspace history: %w", err) @@ -342,6 +385,116 @@ func (w *workspaces) createWorkspaceHistory(rw http.ResponseWriter, r *http.Requ render.JSON(rw, r, convertWorkspaceHistory(workspaceHistory)) } +func (w *workspaces) workspaceHistoryLogs(rw http.ResponseWriter, r *http.Request) { + workspaceHistory := httpmw.WorkspaceHistoryParam(r) + follow := r.URL.Query().Has("follow") + + if !follow { + // If we're not attempting to follow logs, + // we can exit immediately! + logs, err := w.Database.GetWorkspaceHistoryLogsByIDBefore(r.Context(), database.GetWorkspaceHistoryLogsByIDBeforeParams{ + WorkspaceHistoryID: workspaceHistory.ID, + CreatedAt: time.Now(), + }) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get workspace history logs: %s", err), + }) + return + } + render.Status(r, http.StatusOK) + render.JSON(rw, r, logs) + return + } + + // We only want to fetch messages before subscribe, so that + // there aren't any duplicates. + timeBeforeSubscribe := database.Now() + // Start subscribing immediately, otherwise we could miss messages + // that occur during the database read. + newLogNotify := make(chan WorkspaceHistoryLog, 128) + cancelNewLogNotify, err := w.Pubsub.Subscribe(workspaceHistoryLogsChannel(workspaceHistory.ID), func(ctx context.Context, message []byte) { + var logs []database.WorkspaceHistoryLog + err := json.Unmarshal(message, &logs) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("parse logs from publish: %s", err), + }) + return + } + for _, log := range logs { + // If many logs are sent during our database query, this channel + // could overflow. The Go scheduler would decide the order to send + // logs in at that point, which is an unfortunate (but not fatal) + // flaw of this approach. 
+ // + // This is an extremely unlikely outcome given reasonable database + // query times. + newLogNotify <- convertWorkspaceHistoryLog(log) + } + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("listen for new logs: %s", err), + }) + return + } + defer cancelNewLogNotify() + + workspaceHistoryLogs, err := w.Database.GetWorkspaceHistoryLogsByIDBefore(r.Context(), database.GetWorkspaceHistoryLogsByIDBeforeParams{ + WorkspaceHistoryID: workspaceHistory.ID, + CreatedAt: timeBeforeSubscribe, + }) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get workspace history logs: %s", err), + }) + return + } + + // "follow" uses the ndjson format to stream data. + // See: https://canjs.com/doc/can-ndjson-stream.html + rw.Header().Set("Content-Type", "application/stream+json") + rw.WriteHeader(http.StatusOK) + rw.(http.Flusher).Flush() + + // The Go stdlib JSON encoder appends a newline character after message write. + encoder := json.NewEncoder(rw) + for _, workspaceHistoryLog := range workspaceHistoryLogs { + // JSON separated by a newline + err = encoder.Encode(convertWorkspaceHistoryLog(workspaceHistoryLog)) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("marshal: %s", err), + }) + return + } + rw.(http.Flusher).Flush() + } + + for { + select { + case <-r.Context().Done(): + return + case log := <-newLogNotify: + err = encoder.Encode(log) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("marshal follow: %s", err), + }) + return + } + rw.(http.Flusher).Flush() + } + } +} + // Converts the internal workspace representation to a public external-facing model. 
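The streaming loops above flush after each encoded entry; without explicit flushes, net/http buffers the response and a follower would only see logs in bursts. Before the conversion helpers that follow, here is a minimal stand-alone handler demonstrating the same http.Flusher pattern (illustrative only, not part of this change):

```go
package example

import (
	"encoding/json"
	"net/http"
	"time"
)

// streamTicks writes one JSON object per second and flushes each line so the
// client sees it immediately instead of waiting for the buffer to fill.
func streamTicks(rw http.ResponseWriter, r *http.Request) {
	rw.Header().Set("Content-Type", "application/stream+json")
	rw.WriteHeader(http.StatusOK)

	flusher, ok := rw.(http.Flusher)
	enc := json.NewEncoder(rw)
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-r.Context().Done():
			return
		case now := <-ticker.C:
			_ = enc.Encode(map[string]string{"tick": now.Format(time.RFC3339)})
			if ok {
				flusher.Flush()
			}
		}
	}
}
```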
func convertWorkspace(workspace database.Workspace) Workspace { return Workspace(workspace) @@ -352,6 +505,7 @@ func convertWorkspaceHistory(workspaceHistory database.WorkspaceHistory) Workspa //nolint:unconvert return WorkspaceHistory(WorkspaceHistory{ ID: workspaceHistory.ID, + Name: workspaceHistory.Name, CreatedAt: workspaceHistory.CreatedAt, UpdatedAt: workspaceHistory.UpdatedAt, CompletedAt: workspaceHistory.CompletedAt.Time, @@ -363,3 +517,17 @@ func convertWorkspaceHistory(workspaceHistory database.WorkspaceHistory) Workspa Initiator: workspaceHistory.Initiator, }) } + +func convertWorkspaceHistoryLog(workspaceHistoryLog database.WorkspaceHistoryLog) WorkspaceHistoryLog { + return WorkspaceHistoryLog{ + ID: workspaceHistoryLog.ID, + CreatedAt: workspaceHistoryLog.CreatedAt, + Source: workspaceHistoryLog.Source, + Level: workspaceHistoryLog.Level, + Output: workspaceHistoryLog.Output, + } +} + +func workspaceHistoryLogsChannel(workspaceHistoryID uuid.UUID) string { + return fmt.Sprintf("workspace-history-logs:%s", workspaceHistoryID) +} diff --git a/codersdk/projects.go b/codersdk/projects.go index a075ebee084db..56a50228fdada 100644 --- a/codersdk/projects.go +++ b/codersdk/projects.go @@ -84,3 +84,41 @@ func (c *Client) CreateProjectHistory(ctx context.Context, organization, project var projectVersion coderd.ProjectHistory return projectVersion, json.NewDecoder(res.Body).Decode(&projectVersion) } + +func (c *Client) ProjectHistoryLogs(ctx context.Context, organization, project, history string) ([]coderd.ProjectHistoryLog, error) { + res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/projects/%s/%s/history/%s/logs", organization, project, history), nil) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, readBodyAsError(res) + } + var logs []coderd.ProjectHistoryLog + return logs, json.NewDecoder(res.Body).Decode(&logs) +} + +func (c *Client) FollowProjectHistoryLogs(ctx context.Context, organization, project, history string) (<-chan coderd.ProjectHistoryLog, error) { + res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/projects/%s/%s/history/%s/logs?follow", organization, project, history), nil) + if err != nil { + return nil, err + } + if res.StatusCode != http.StatusOK { + defer res.Body.Close() + return nil, readBodyAsError(res) + } + + logs := make(chan coderd.ProjectHistoryLog) + decoder := json.NewDecoder(res.Body) + go func() { + defer close(logs) + var log coderd.ProjectHistoryLog + for { + err = decoder.Decode(&log) + if err != nil { + return + } + } + }() + return logs, nil +} diff --git a/codersdk/provisioners.go b/codersdk/provisioners.go new file mode 100644 index 0000000000000..afafc58ed4bae --- /dev/null +++ b/codersdk/provisioners.go @@ -0,0 +1,51 @@ +package codersdk + +import ( + "context" + "encoding/json" + "net/http" + + "golang.org/x/xerrors" + "nhooyr.io/websocket" + + "github.com/hashicorp/yamux" + + "github.com/coder/coder/coderd" + "github.com/coder/coder/provisionerd/proto" + "github.com/coder/coder/provisionersdk" +) + +func (c *Client) ProvisionerDaemons(ctx context.Context) ([]coderd.ProvisionerDaemon, error) { + res, err := c.request(ctx, http.MethodGet, "/api/v2/provisioners/daemons", nil) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, readBodyAsError(res) + } + var daemons []coderd.ProvisionerDaemon + return daemons, json.NewDecoder(res.Body).Decode(&daemons) +} + +// 
ProvisionerDaemonClient returns the gRPC service for a provisioner daemon implementation. +func (c *Client) ProvisionerDaemonClient(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + serverURL, err := c.url.Parse("/api/v2/provisionerd") + if err != nil { + return nil, xerrors.Errorf("parse url: %w", err) + } + conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{ + HTTPClient: c.httpClient, + }) + if err != nil { + if res == nil { + return nil, err + } + return nil, readBodyAsError(res) + } + session, err := yamux.Client(websocket.NetConn(context.Background(), conn, websocket.MessageBinary), nil) + if err != nil { + return nil, xerrors.Errorf("multiplex client: %w", err) + } + return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(session)), nil +} diff --git a/codersdk/provisioners_test.go b/codersdk/provisioners_test.go new file mode 100644 index 0000000000000..d87c21d60cdf7 --- /dev/null +++ b/codersdk/provisioners_test.go @@ -0,0 +1,20 @@ +package codersdk_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/coderdtest" +) + +func TestProvisioners(t *testing.T) { + t.Parallel() + t.Run("ListDaemons", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + _, err := server.Client.ProvisionerDaemons(context.Background()) + require.NoError(t, err) + }) +} diff --git a/codersdk/workspaces.go b/codersdk/workspaces.go index 937f58e861b11..256aa244503e6 100644 --- a/codersdk/workspaces.go +++ b/codersdk/workspaces.go @@ -127,3 +127,33 @@ func (c *Client) CreateWorkspaceHistory(ctx context.Context, owner, workspace st var workspaceHistory coderd.WorkspaceHistory return workspaceHistory, json.NewDecoder(res.Body).Decode(&workspaceHistory) } + +func (c *Client) FollowWorkspaceHistoryLogs(ctx context.Context, owner, workspace, history string) (<-chan coderd.WorkspaceHistoryLog, error) { + res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaces/%s/%s/history/%s/logs?follow", owner, workspace, history), nil) + if err != nil { + return nil, err + } + if res.StatusCode != http.StatusOK { + defer res.Body.Close() + return nil, readBodyAsError(res) + } + + logs := make(chan coderd.WorkspaceHistoryLog) + decoder := json.NewDecoder(res.Body) + go func() { + defer close(logs) + var log coderd.WorkspaceHistoryLog + for { + err = decoder.Decode(&log) + if err != nil { + return + } + select { + case <-ctx.Done(): + return + case logs <- log: + } + } + }() + return logs, nil +} diff --git a/database/migrations/000002_projects.up.sql b/database/migrations/000002_projects.up.sql index 6a021cc988707..9a0df4fb44d43 100644 --- a/database/migrations/000002_projects.up.sql +++ b/database/migrations/000002_projects.up.sql @@ -111,4 +111,4 @@ CREATE TABLE project_history_log ( source log_source NOT NULL, level log_level NOT NULL, output varchar(1024) NOT NULL -); \ No newline at end of file +); diff --git a/database/migrations/000003_workspaces.up.sql b/database/migrations/000003_workspaces.up.sql index 60fc1c0d9d8ab..7e7172483d7c1 100644 --- a/database/migrations/000003_workspaces.up.sql +++ b/database/migrations/000003_workspaces.up.sql @@ -73,4 +73,4 @@ CREATE TABLE workspace_history_log ( source log_source NOT NULL, level log_level NOT NULL, output varchar(1024) NOT NULL -); \ No newline at end of file +); diff --git a/go.mod b/go.mod index 45cf8c8cfca3a..5577cb8758c1f 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,8 @@ replace 
github.com/hashicorp/terraform-config-inspect => github.com/kylecarbs/te require ( cdr.dev/slog v1.4.1 + github.com/coder/retry v1.3.0 + github.com/go-chi/chi v1.5.4 github.com/go-chi/chi/v5 v5.0.7 github.com/go-chi/render v1.0.1 github.com/go-playground/validator/v10 v10.10.0 @@ -18,6 +20,7 @@ require ( github.com/hashicorp/go-version v1.4.0 github.com/hashicorp/terraform-config-inspect v0.0.0-20211115214459-90acf1ca460f github.com/hashicorp/terraform-exec v0.15.0 + github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87 github.com/justinas/nosurf v1.1.1 github.com/lib/pq v1.10.4 github.com/moby/moby v20.10.12+incompatible @@ -35,6 +38,7 @@ require ( golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 google.golang.org/protobuf v1.27.1 + nhooyr.io/websocket v1.8.7 storj.io/drpc v0.0.29 ) @@ -70,6 +74,7 @@ require ( github.com/hashicorp/terraform-json v0.13.0 // indirect github.com/imdario/mergo v0.3.12 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect + github.com/klauspost/compress v1.13.6 // indirect github.com/leodido/go-urn v1.2.1 // indirect github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect diff --git a/go.sum b/go.sum index cf80329fa859d..05291d42a553c 100644 --- a/go.sum +++ b/go.sum @@ -232,6 +232,8 @@ github.com/cncf/xds/go v0.0.0-20211130200136-a8f946100490/go.mod h1:eXthEFrGJvWH github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/cockroachdb/cockroach-go/v2 v2.1.1/go.mod h1:7NtUnP6eK+l6k483WSYNrq3Kb23bWV10IRV1TyeSpwM= github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= +github.com/coder/retry v1.3.0 h1:5lAAwt/2Cm6lVmnfBY7sOMXcBOwcwJhmV5QGSELIVWY= +github.com/coder/retry v1.3.0/go.mod h1:tXuRgZgWjUnU5LZPT4lJh4ew2elUhexhlnXzrJWdyFY= github.com/containerd/aufs v0.0.0-20200908144142-dab0cbea06f4/go.mod h1:nukgQABAEopAHvB6j7cnP5zJ+/3aVcE7hCYqvIwAHyE= github.com/containerd/aufs v0.0.0-20201003224125-76a6863f2989/go.mod h1:AkGGQs9NM2vtYHaUen+NljV0/baGCAPELGm2q9ZXpWU= github.com/containerd/aufs v0.0.0-20210316121734-20793ff83c97/go.mod h1:kL5kd6KM5TzQjR79jljyi4olc1Vrx6XBlcyj3gNv2PU= @@ -430,7 +432,13 @@ github.com/gabriel-vasile/mimetype v1.4.0/go.mod h1:fA8fi6KUiG7MgQQ+mEWotXoEOvmx github.com/garyburd/redigo v0.0.0-20150301180006-535138d7bcd7/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14= +github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= github.com/gliderlabs/ssh v0.2.2/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= +github.com/go-chi/chi v1.5.4 h1:QHdzF2szwjqVV4wmByUnTcsbIg7UGaQ0tPF2t5GcAIs= +github.com/go-chi/chi v1.5.4/go.mod h1:uaf8YgoFazUOkPBG7fxPftUylNumIev9awIWOENIuEg= github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8= github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8= @@ -467,10 +475,13 @@ 
github.com/go-openapi/swag v0.19.2/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= +github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= github.com/go-playground/validator/v10 v10.10.0 h1:I7mrTYv78z8k8VXa/qJlOlEXn/nBh+BF8dHX5nt/dr0= github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= @@ -504,6 +515,12 @@ github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWe github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ= github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0= github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= +github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= +github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/gocql/gocql v0.0.0-20210515062232-b7ef815b4556/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY= github.com/godbus/dbus v0.0.0-20151105175453-c7fdd8b5cd55/go.mod h1:/YcGZj5zSblfDWMMoOzV4fas9FZnQYTkDnsGvmh2Grw= github.com/godbus/dbus v0.0.0-20180201030542-885f9cc04c9c/go.mod h1:/YcGZj5zSblfDWMMoOzV4fas9FZnQYTkDnsGvmh2Grw= @@ -629,6 +646,8 @@ github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2z github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= +github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= @@ -688,6 
+707,8 @@ github.com/hashicorp/serf v0.9.5/go.mod h1:UWDWwZeL5cuWDJdl0C6wrvrUwEqtQ4ZKBKKEN github.com/hashicorp/serf v0.9.6/go.mod h1:TXZNMjZQijwlDvp+r0b63xZ45H7JmCmgg4gpTwn9UV4= github.com/hashicorp/terraform-json v0.13.0 h1:Li9L+lKD1FO5RVFRM1mMMIBDoUHslOniyEi5CM+FWGY= github.com/hashicorp/terraform-json v0.13.0/go.mod h1:y5OdLBCT+rxbwnpxZs9kGL7R9ExU76+cpdY8zHwoazk= +github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87 h1:xixZ2bWeofWV68J+x6AzmKuVM/JWCQwkWm6GW/MUR6I= +github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= @@ -763,6 +784,7 @@ github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/u github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= @@ -785,10 +807,12 @@ github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQL github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.9.5/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.13/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= +github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -814,6 +838,7 @@ github.com/kylecarbs/terraform-exec v0.15.1-0.20220129210610-65894a884c09/go.mod github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/leodido/go-urn v1.2.1/go.mod 
h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/lib/pq v0.0.0-20180327071824-d34b9ff171c2/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -897,9 +922,11 @@ github.com/moby/term v0.0.0-20201216013528-df9cb8a40635/go.mod h1:FBS0z0QWA44HXy github.com/moby/term v0.0.0-20210619224110-3f7ff695adc6 h1:dcztxKSvZ4Id8iPpHERQBbIJfabdt4wUm5qy3wOL2Zc= github.com/moby/term v0.0.0-20210619224110-3f7ff695adc6/go.mod h1:E2VnQOmVuvZB6UYnnDB0qG5Nq/1tD9acaOpo6xmt0Kw= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= @@ -1157,6 +1184,10 @@ github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1 github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= github.com/tv42/httpunix v0.0.0-20191220191345-2ba4b9c3382c/go.mod h1:hzIxponao9Kjc7aWznkXaL4U4TWaDSs8zcsY4Ka08nM= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= +github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= +github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= +github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/unrolled/secure v1.0.9 h1:BWRuEb1vDrBFFDdbCnKkof3gZ35I/bnHGyt0LB0TNyQ= github.com/unrolled/secure v1.0.9/go.mod h1:fO+mEan+FLB0CdEnHf6Q4ZZVNqG+5fuLFnP8p0BXDPI= github.com/urfave/cli v0.0.0-20171014202726-7bc6a0acffa5/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= @@ -1929,6 +1960,8 @@ modernc.org/token v1.0.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= modernc.org/z v1.0.1-0.20210308123920-1f282aa71362/go.mod h1:8/SRk5C/HgiQWCgXdfpb+1RvhORdkz5sw72d3jjtyqA= modernc.org/z v1.0.1/go.mod h1:8/SRk5C/HgiQWCgXdfpb+1RvhORdkz5sw72d3jjtyqA= modernc.org/zappy v1.0.0/go.mod h1:hHe+oGahLVII/aTTyWK/b53VDHMAGCBYYeZ9sn83HC4= +nhooyr.io/websocket v1.8.7 h1:usjR2uOr/zjjkVMy0lW+PPohFok7PCow5sDjLgX4P4g= +nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= diff --git a/httpmw/projecthistoryparam.go b/httpmw/projecthistoryparam.go new file mode 100644 index 0000000000000..48702163bc3dc --- /dev/null +++ b/httpmw/projecthistoryparam.go @@ -0,0 +1,60 @@ +package httpmw + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/http" + + "github.com/go-chi/chi" + + 
"github.com/coder/coder/database" + "github.com/coder/coder/httpapi" +) + +type projectHistoryParamContextKey struct{} + +// ProjectHistoryParam returns the project history from the ExtractProjectHistoryParam handler. +func ProjectHistoryParam(r *http.Request) database.ProjectHistory { + projectHistory, ok := r.Context().Value(projectHistoryParamContextKey{}).(database.ProjectHistory) + if !ok { + panic("developer error: project history param middleware not provided") + } + return projectHistory +} + +// ExtractProjectHistoryParam grabs project history from the "projecthistory" URL parameter. +func ExtractProjectHistoryParam(db database.Store) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + project := ProjectParam(r) + projectHistoryName := chi.URLParam(r, "projecthistory") + if projectHistoryName == "" { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: "project history name must be provided", + }) + return + } + projectHistory, err := db.GetProjectHistoryByProjectIDAndName(r.Context(), database.GetProjectHistoryByProjectIDAndNameParams{ + ProjectID: project.ID, + Name: projectHistoryName, + }) + if errors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusNotFound, httpapi.Response{ + Message: fmt.Sprintf("project history %q does not exist", projectHistoryName), + }) + return + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get project history: %s", err.Error()), + }) + return + } + + ctx := context.WithValue(r.Context(), projectHistoryParamContextKey{}, projectHistory) + next.ServeHTTP(rw, r.WithContext(ctx)) + }) + } +} diff --git a/httpmw/projecthistoryparam_test.go b/httpmw/projecthistoryparam_test.go new file mode 100644 index 0000000000000..c72b6fe37be66 --- /dev/null +++ b/httpmw/projecthistoryparam_test.go @@ -0,0 +1,161 @@ +package httpmw_test + +import ( + "context" + "crypto/sha256" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/cryptorand" + "github.com/coder/coder/database" + "github.com/coder/coder/database/databasefake" + "github.com/coder/coder/httpmw" +) + +func TestProjectHistoryParam(t *testing.T) { + t.Parallel() + + setupAuthentication := func(db database.Store) (*http.Request, database.Project) { + var ( + id, secret = randomAPIKeyParts() + hashed = sha256.Sum256([]byte(secret)) + ) + r := httptest.NewRequest("GET", "/", nil) + r.AddCookie(&http.Cookie{ + Name: httpmw.AuthCookie, + Value: fmt.Sprintf("%s-%s", id, secret), + }) + userID, err := cryptorand.String(16) + require.NoError(t, err) + username, err := cryptorand.String(8) + require.NoError(t, err) + user, err := db.InsertUser(r.Context(), database.InsertUserParams{ + ID: userID, + Email: "testaccount@coder.com", + Name: "example", + LoginType: database.LoginTypeBuiltIn, + HashedPassword: hashed[:], + Username: username, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + }) + require.NoError(t, err) + _, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ + ID: id, + UserID: user.ID, + HashedSecret: hashed[:], + LastUsed: database.Now(), + ExpiresAt: database.Now().Add(time.Minute), + }) + require.NoError(t, err) + orgID, err := cryptorand.String(16) + require.NoError(t, err) + organization, err := db.InsertOrganization(r.Context(), 
database.InsertOrganizationParams{ + ID: orgID, + Name: "banana", + Description: "wowie", + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + }) + require.NoError(t, err) + _, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{ + OrganizationID: orgID, + UserID: user.ID, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + }) + require.NoError(t, err) + project, err := db.InsertProject(context.Background(), database.InsertProjectParams{ + ID: uuid.New(), + OrganizationID: organization.ID, + Name: "moo", + }) + require.NoError(t, err) + + ctx := chi.NewRouteContext() + ctx.URLParams.Add("organization", organization.Name) + ctx.URLParams.Add("project", project.Name) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx)) + return r, project + } + + t.Run("None", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractAPIKey(db, nil), + httpmw.ExtractOrganizationParam(db), + httpmw.ExtractProjectParam(db), + httpmw.ExtractProjectHistoryParam(db), + ) + rtr.Get("/", nil) + r, _ := setupAuthentication(db) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusBadRequest, res.StatusCode) + }) + + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractAPIKey(db, nil), + httpmw.ExtractOrganizationParam(db), + httpmw.ExtractProjectParam(db), + httpmw.ExtractProjectHistoryParam(db), + ) + rtr.Get("/", nil) + + r, _ := setupAuthentication(db) + chi.RouteContext(r.Context()).URLParams.Add("projecthistory", "nothin") + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusNotFound, res.StatusCode) + }) + + t.Run("ProjectHistory", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractAPIKey(db, nil), + httpmw.ExtractOrganizationParam(db), + httpmw.ExtractProjectParam(db), + httpmw.ExtractProjectHistoryParam(db), + ) + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + _ = httpmw.ProjectHistoryParam(r) + rw.WriteHeader(http.StatusOK) + }) + + r, project := setupAuthentication(db) + projectHistory, err := db.InsertProjectHistory(context.Background(), database.InsertProjectHistoryParams{ + ID: uuid.New(), + ProjectID: project.ID, + Name: "moo", + }) + require.NoError(t, err) + chi.RouteContext(r.Context()).URLParams.Add("projecthistory", projectHistory.Name) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) +} diff --git a/httpmw/workspacehistoryparam.go b/httpmw/workspacehistoryparam.go new file mode 100644 index 0000000000000..e210414290d21 --- /dev/null +++ b/httpmw/workspacehistoryparam.go @@ -0,0 +1,60 @@ +package httpmw + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/http" + + "github.com/go-chi/chi" + + "github.com/coder/coder/database" + "github.com/coder/coder/httpapi" +) + +type workspaceHistoryParamContextKey struct{} + +// WorkspaceHistoryParam returns the workspace history from the ExtractWorkspaceHistoryParam handler. 
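+// It panics if ExtractWorkspaceHistoryParam did not run on the request, which indicates a route wiring error.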
+func WorkspaceHistoryParam(r *http.Request) database.WorkspaceHistory { + workspaceHistory, ok := r.Context().Value(workspaceHistoryParamContextKey{}).(database.WorkspaceHistory) + if !ok { + panic("developer error: workspace history param middleware not provided") + } + return workspaceHistory +} + +// ExtractWorkspaceHistoryParam grabs workspace history from the "workspacehistory" URL parameter. +func ExtractWorkspaceHistoryParam(db database.Store) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + workspace := WorkspaceParam(r) + workspaceHistoryName := chi.URLParam(r, "workspacehistory") + if workspaceHistoryName == "" { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: "workspace history name must be provided", + }) + return + } + workspaceHistory, err := db.GetWorkspaceHistoryByWorkspaceIDAndName(r.Context(), database.GetWorkspaceHistoryByWorkspaceIDAndNameParams{ + WorkspaceID: workspace.ID, + Name: workspaceHistoryName, + }) + if errors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusNotFound, httpapi.Response{ + Message: fmt.Sprintf("workspace history %q does not exist", workspaceHistoryName), + }) + return + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get workspace history: %s", err.Error()), + }) + return + } + + ctx := context.WithValue(r.Context(), workspaceHistoryParamContextKey{}, workspaceHistory) + next.ServeHTTP(rw, r.WithContext(ctx)) + }) + } +} diff --git a/httpmw/workspacehistoryparam_test.go b/httpmw/workspacehistoryparam_test.go new file mode 100644 index 0000000000000..374a501eeabdd --- /dev/null +++ b/httpmw/workspacehistoryparam_test.go @@ -0,0 +1,145 @@ +package httpmw_test + +import ( + "context" + "crypto/sha256" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/cryptorand" + "github.com/coder/coder/database" + "github.com/coder/coder/database/databasefake" + "github.com/coder/coder/httpmw" +) + +func TestWorkspaceHistoryParam(t *testing.T) { + t.Parallel() + + setupAuthentication := func(db database.Store) (*http.Request, database.Workspace) { + var ( + id, secret = randomAPIKeyParts() + hashed = sha256.Sum256([]byte(secret)) + ) + r := httptest.NewRequest("GET", "/", nil) + r.AddCookie(&http.Cookie{ + Name: httpmw.AuthCookie, + Value: fmt.Sprintf("%s-%s", id, secret), + }) + userID, err := cryptorand.String(16) + require.NoError(t, err) + username, err := cryptorand.String(8) + require.NoError(t, err) + user, err := db.InsertUser(r.Context(), database.InsertUserParams{ + ID: userID, + Email: "testaccount@coder.com", + Name: "example", + LoginType: database.LoginTypeBuiltIn, + HashedPassword: hashed[:], + Username: username, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + }) + require.NoError(t, err) + _, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ + ID: id, + UserID: user.ID, + HashedSecret: hashed[:], + LastUsed: database.Now(), + ExpiresAt: database.Now().Add(time.Minute), + }) + require.NoError(t, err) + workspace, err := db.InsertWorkspace(context.Background(), database.InsertWorkspaceParams{ + ID: uuid.New(), + ProjectID: uuid.New(), + OwnerID: user.ID, + Name: "potato", + }) + require.NoError(t, err) + + ctx := chi.NewRouteContext() + ctx.URLParams.Add("user", userID) + 
ctx.URLParams.Add("workspace", workspace.Name) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx)) + return r, workspace + } + + t.Run("None", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractAPIKey(db, nil), + httpmw.ExtractUserParam(db), + httpmw.ExtractWorkspaceParam(db), + httpmw.ExtractWorkspaceHistoryParam(db), + ) + rtr.Get("/", nil) + r, _ := setupAuthentication(db) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusBadRequest, res.StatusCode) + }) + + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractAPIKey(db, nil), + httpmw.ExtractUserParam(db), + httpmw.ExtractWorkspaceParam(db), + httpmw.ExtractWorkspaceHistoryParam(db), + ) + rtr.Get("/", nil) + + r, _ := setupAuthentication(db) + chi.RouteContext(r.Context()).URLParams.Add("workspacehistory", "nothin") + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusNotFound, res.StatusCode) + }) + + t.Run("WorkspaceHistory", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractAPIKey(db, nil), + httpmw.ExtractUserParam(db), + httpmw.ExtractWorkspaceParam(db), + httpmw.ExtractWorkspaceHistoryParam(db), + ) + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + _ = httpmw.WorkspaceHistoryParam(r) + rw.WriteHeader(http.StatusOK) + }) + + r, workspace := setupAuthentication(db) + workspaceHistory, err := db.InsertWorkspaceHistory(context.Background(), database.InsertWorkspaceHistoryParams{ + ID: uuid.New(), + WorkspaceID: workspace.ID, + Name: "moo", + }) + require.NoError(t, err) + chi.RouteContext(r.Context()).URLParams.Add("workspacehistory", workspaceHistory.Name) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) +} diff --git a/peerbroker/dial_test.go b/peerbroker/dial_test.go index 30066b8d82397..537dc157e0a79 100644 --- a/peerbroker/dial_test.go +++ b/peerbroker/dial_test.go @@ -7,7 +7,6 @@ import ( "github.com/pion/webrtc/v3" "github.com/stretchr/testify/require" "go.uber.org/goleak" - "storj.io/drpc/drpcconn" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" @@ -37,7 +36,7 @@ func TestDial(t *testing.T) { }) require.NoError(t, err) - api := proto.NewDRPCPeerBrokerClient(drpcconn.New(client)) + api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) stream, err := api.NegotiateConnection(ctx) require.NoError(t, err) clientConn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{ diff --git a/peerbroker/listen.go b/peerbroker/listen.go index fa68023689c7e..8e92fe7a7c82d 100644 --- a/peerbroker/listen.go +++ b/peerbroker/listen.go @@ -4,12 +4,12 @@ import ( "context" "errors" "io" + "net" "reflect" "sync" "github.com/pion/webrtc/v3" "golang.org/x/xerrors" - "storj.io/drpc" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" @@ -19,7 +19,7 @@ import ( // Listen consumes the transport as the server-side of the PeerBroker dRPC service. // The Accept function must be serviced, or new connections will hang. 
-func Listen(transport drpc.Transport, opts *peer.ConnOptions) (*Listener, error) { +func Listen(connListener net.Listener, opts *peer.ConnOptions) (*Listener, error) { ctx, cancelFunc := context.WithCancel(context.Background()) listener := &Listener{ connectionChannel: make(chan *peer.Conn), @@ -39,7 +39,7 @@ func Listen(transport drpc.Transport, opts *peer.ConnOptions) (*Listener, error) } srv := drpcserver.New(mux) go func() { - err := srv.ServeOne(ctx, transport) + err := srv.Serve(ctx, connListener) _ = listener.closeWithError(err) }() diff --git a/peerbroker/listen_test.go b/peerbroker/listen_test.go index c66d8a480a176..81582a91d4b84 100644 --- a/peerbroker/listen_test.go +++ b/peerbroker/listen_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/stretchr/testify/require" - "storj.io/drpc/drpcconn" "github.com/coder/coder/peerbroker" "github.com/coder/coder/peerbroker/proto" @@ -27,7 +26,7 @@ func TestListen(t *testing.T) { listener, err := peerbroker.Listen(server, nil) require.NoError(t, err) - api := proto.NewDRPCPeerBrokerClient(drpcconn.New(client)) + api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) stream, err := api.NegotiateConnection(ctx) require.NoError(t, err) clientConn, err := peerbroker.Dial(stream, nil, nil) diff --git a/provisioner/terraform/parse_test.go b/provisioner/terraform/parse_test.go index e678d1d36c674..9d5bec03338f8 100644 --- a/provisioner/terraform/parse_test.go +++ b/provisioner/terraform/parse_test.go @@ -10,7 +10,6 @@ import ( "testing" "github.com/stretchr/testify/require" - "storj.io/drpc/drpcconn" "github.com/coder/coder/provisionersdk" "github.com/coder/coder/provisionersdk/proto" @@ -30,12 +29,12 @@ func TestParse(t *testing.T) { go func() { err := Serve(ctx, &ServeOptions{ ServeOptions: &provisionersdk.ServeOptions{ - Transport: server, + Listener: server, }, }) require.NoError(t, err) }() - api := proto.NewDRPCProvisionerClient(drpcconn.New(client)) + api := proto.NewDRPCProvisionerClient(provisionersdk.Conn(client)) for _, testCase := range []struct { Name string diff --git a/provisioner/terraform/provision_test.go b/provisioner/terraform/provision_test.go index 07ac1bde9dace..27117daa8464a 100644 --- a/provisioner/terraform/provision_test.go +++ b/provisioner/terraform/provision_test.go @@ -10,7 +10,6 @@ import ( "testing" "github.com/stretchr/testify/require" - "storj.io/drpc/drpcconn" "github.com/coder/coder/provisionersdk" "github.com/coder/coder/provisionersdk/proto" @@ -29,12 +28,12 @@ func TestProvision(t *testing.T) { go func() { err := Serve(ctx, &ServeOptions{ ServeOptions: &provisionersdk.ServeOptions{ - Transport: server, + Listener: server, }, }) require.NoError(t, err) }() - api := proto.NewDRPCProvisionerClient(drpcconn.New(client)) + api := proto.NewDRPCProvisionerClient(provisionersdk.Conn(client)) for _, testCase := range []struct { Name string diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go new file mode 100644 index 0000000000000..2ce59bd0810fb --- /dev/null +++ b/provisionerd/provisionerd.go @@ -0,0 +1,466 @@ +package provisionerd + +import ( + "archive/tar" + "bytes" + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "reflect" + "strings" + "sync" + "time" + + "go.uber.org/atomic" + + "cdr.dev/slog" + "github.com/coder/coder/provisionerd/proto" + sdkproto "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/retry" +) + +// Dialer represents the function to create a daemon client connection. 
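+// provisionerd redials with it on startup and whenever the update stream to coderd drops.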
+type Dialer func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) + +// Provisioners maps provisioner ID to implementation. +type Provisioners map[string]sdkproto.DRPCProvisionerClient + +// Options provides customizations to the behavior of a provisioner daemon. +type Options struct { + Logger slog.Logger + + PollInterval time.Duration + Provisioners Provisioners + WorkDirectory string +} + +// New creates and starts a provisioner daemon. +func New(clientDialer Dialer, opts *Options) io.Closer { + if opts.PollInterval == 0 { + opts.PollInterval = 5 * time.Second + } + ctx, ctxCancel := context.WithCancel(context.Background()) + daemon := &provisionerDaemon{ + clientDialer: clientDialer, + opts: opts, + + closeContext: ctx, + closeContextCancel: ctxCancel, + closed: make(chan struct{}), + } + go daemon.connect() + return daemon +} + +type provisionerDaemon struct { + opts *Options + + clientDialer Dialer + connectMutex sync.Mutex + client proto.DRPCProvisionerDaemonClient + updateStream proto.DRPCProvisionerDaemon_UpdateJobClient + + closeContext context.Context + closeContextCancel context.CancelFunc + closed chan struct{} + closeMutex sync.Mutex + closeError error + + runningJob *proto.AcquiredJob + runningJobContext context.Context + runningJobContextCancel context.CancelFunc + runningJobMutex sync.Mutex + isRunningJob atomic.Bool +} + +// connect establishes a connection to coderd. +func (p *provisionerDaemon) connect() { + p.connectMutex.Lock() + defer p.connectMutex.Unlock() + + var err error + for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(p.closeContext); { + p.client, err = p.clientDialer(p.closeContext) + if err != nil { + // Warn and let the retrier back off before dialing again. + p.opts.Logger.Warn(context.Background(), "failed to dial", slog.Error(err)) + continue + } + p.updateStream, err = p.client.UpdateJob(p.closeContext) + if err != nil { + p.opts.Logger.Warn(context.Background(), "create update job stream", slog.Error(err)) + continue + } + p.opts.Logger.Debug(context.Background(), "connected") + break + } + + go func() { + if p.isClosed() { + return + } + select { + case <-p.closed: + return + case <-p.updateStream.Context().Done(): + // We use the update stream to detect when the connection + // has been interrupted. This works well, because logs need + // to buffer if a job is running in the background.
+ p.opts.Logger.Debug(context.Background(), "update stream ended", slog.Error(p.updateStream.Context().Err())) + p.connect() + } + }() + + go func() { + if p.isClosed() { + return + } + ticker := time.NewTicker(p.opts.PollInterval) + defer ticker.Stop() + for { + select { + case <-p.closed: + return + case <-p.updateStream.Context().Done(): + return + case <-ticker.C: + p.acquireJob() + } + } + }() +} + +func (p *provisionerDaemon) acquireJob() { + p.runningJobMutex.Lock() + defer p.runningJobMutex.Unlock() + if p.isRunningJob.Load() { + p.opts.Logger.Debug(context.Background(), "skipping acquire; job is already running") + return + } + var err error + p.runningJob, err = p.client.AcquireJob(p.closeContext, &proto.Empty{}) + if err != nil { + p.opts.Logger.Error(context.Background(), "acquire job", slog.Error(err)) + return + } + if p.runningJob.JobId == "" { + p.opts.Logger.Debug(context.Background(), "no jobs available") + return + } + p.runningJobContext, p.runningJobContextCancel = context.WithCancel(p.closeContext) + p.isRunningJob.Store(true) + + p.opts.Logger.Info(context.Background(), "acquired job", + slog.F("organization_name", p.runningJob.OrganizationName), + slog.F("project_name", p.runningJob.ProjectName), + slog.F("username", p.runningJob.UserName), + slog.F("provisioner", p.runningJob.Provisioner), + ) + + go p.runJob() +} + +func (p *provisionerDaemon) runJob() { + // It's safe to cast this ProvisionerType. This data is coming directly from coderd. + provisioner, hasProvisioner := p.opts.Provisioners[p.runningJob.Provisioner] + if !hasProvisioner { + p.cancelActiveJob(fmt.Sprintf("provisioner %q not registered", p.runningJob.Provisioner)) + return + } + defer func() { + // Cleanup the work directory after execution. + err := os.RemoveAll(p.opts.WorkDirectory) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("remove all from %q directory: %s", p.opts.WorkDirectory, err)) + return + } + p.opts.Logger.Debug(context.Background(), "cleaned up work directory") + }() + + err := os.MkdirAll(p.opts.WorkDirectory, 0600) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("create work directory %q: %s", p.opts.WorkDirectory, err)) + return + } + + p.opts.Logger.Info(context.Background(), "unpacking project source archive", slog.F("size_bytes", len(p.runningJob.ProjectSourceArchive))) + reader := tar.NewReader(bytes.NewBuffer(p.runningJob.ProjectSourceArchive)) + for { + header, err := reader.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + p.cancelActiveJob(fmt.Sprintf("read project source archive: %s", err)) + return + } + // #nosec + path := filepath.Join(p.opts.WorkDirectory, header.Name) + if !strings.HasPrefix(path, filepath.Clean(p.opts.WorkDirectory)) { + p.cancelActiveJob("tar attempts to target relative upper directory") + return + } + mode := header.FileInfo().Mode() + if mode == 0 { + mode = 0600 + } + switch header.Typeflag { + case tar.TypeDir: + err = os.MkdirAll(path, mode) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("mkdir %q: %s", path, err)) + return + } + p.opts.Logger.Debug(context.Background(), "extracted directory", slog.F("path", path)) + case tar.TypeReg: + file, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, mode) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("create file %q: %s", path, err)) + return + } + // Max file size of 10MB. 
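+ // io.CopyN returns io.EOF when the archived file is smaller than the limit, which is the expected case.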
+ size, err := io.CopyN(file, reader, (1<<20)*10) + if errors.Is(err, io.EOF) { + err = nil + } + if err != nil { + p.cancelActiveJob(fmt.Sprintf("copy file %q: %s", path, err)) + return + } + err = file.Close() + if err != nil { + p.cancelActiveJob(fmt.Sprintf("close file %q: %s", path, err)) + return + } + p.opts.Logger.Debug(context.Background(), "extracted file", + slog.F("size_bytes", size), + slog.F("path", path), + slog.F("mode", mode), + ) + } + } + + switch jobType := p.runningJob.Type.(type) { + case *proto.AcquiredJob_ProjectImport_: + p.opts.Logger.Debug(context.Background(), "acquired job is project import", + slog.F("project_history_name", jobType.ProjectImport.ProjectHistoryName), + ) + + p.runProjectImport(provisioner, jobType) + case *proto.AcquiredJob_WorkspaceProvision_: + p.opts.Logger.Debug(context.Background(), "acquired job is workspace provision", + slog.F("workspace_name", jobType.WorkspaceProvision.WorkspaceName), + slog.F("state_length", len(jobType.WorkspaceProvision.State)), + slog.F("parameters", jobType.WorkspaceProvision.ParameterValues), + ) + + p.runWorkspaceProvision(provisioner, jobType) + default: + p.cancelActiveJob(fmt.Sprintf("unknown job type %q; ensure your provisioner daemon is up-to-date", reflect.TypeOf(p.runningJob.Type).String())) + return + } + + p.opts.Logger.Info(context.Background(), "completed job") + p.isRunningJob.Store(false) +} + +func (p *provisionerDaemon) runProjectImport(provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_ProjectImport_) { + stream, err := provisioner.Parse(p.runningJobContext, &sdkproto.Parse_Request{ + Directory: p.opts.WorkDirectory, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("parse source: %s", err)) + return + } + defer stream.Close() + for { + msg, err := stream.Recv() + if err != nil { + p.cancelActiveJob(fmt.Sprintf("recv parse source: %s", err)) + return + } + switch msgType := msg.Type.(type) { + case *sdkproto.Parse_Response_Log: + p.opts.Logger.Debug(context.Background(), "parse job logged", + slog.F("level", msgType.Log.Level), + slog.F("output", msgType.Log.Output), + slog.F("project_history_id", job.ProjectImport.ProjectHistoryId), + ) + + err = p.updateStream.Send(&proto.JobUpdate{ + JobId: p.runningJob.JobId, + ProjectImportLogs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER, + Level: msgType.Log.Level, + CreatedAt: time.Now().UTC().UnixMilli(), + Output: msgType.Log.Output, + }}, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("update job: %s", err)) + return + } + case *sdkproto.Parse_Response_Complete: + _, err = p.client.CompleteJob(p.runningJobContext, &proto.CompletedJob{ + JobId: p.runningJob.JobId, + Type: &proto.CompletedJob_ProjectImport_{ + ProjectImport: &proto.CompletedJob_ProjectImport{ + ParameterSchemas: msgType.Complete.ParameterSchemas, + }, + }, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("complete job: %s", err)) + return + } + // Return so we stop looping! 
+ return + default: + p.cancelActiveJob(fmt.Sprintf("invalid message type %q received from provisioner", + reflect.TypeOf(msg.Type).String())) + return + } + } +} + +func (p *provisionerDaemon) runWorkspaceProvision(provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_WorkspaceProvision_) { + stream, err := provisioner.Provision(p.closeContext, &sdkproto.Provision_Request{ + Directory: p.opts.WorkDirectory, + ParameterValues: job.WorkspaceProvision.ParameterValues, + State: job.WorkspaceProvision.State, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("provision: %s", err)) + return + } + defer stream.Close() + + for { + msg, err := stream.Recv() + if err != nil { + p.cancelActiveJob(fmt.Sprintf("recv workspace provision: %s", err)) + return + } + switch msgType := msg.Type.(type) { + case *sdkproto.Provision_Response_Log: + p.opts.Logger.Debug(context.Background(), "workspace provision job logged", + slog.F("level", msgType.Log.Level), + slog.F("output", msgType.Log.Output), + slog.F("workspace_history_id", job.WorkspaceProvision.WorkspaceHistoryId), + ) + + err = p.updateStream.Send(&proto.JobUpdate{ + JobId: p.runningJob.JobId, + WorkspaceProvisionLogs: []*proto.Log{{ + Source: proto.LogSource_PROVISIONER, + Level: msgType.Log.Level, + CreatedAt: time.Now().UTC().UnixMilli(), + Output: msgType.Log.Output, + }}, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("send job update: %s", err)) + return + } + case *sdkproto.Provision_Response_Complete: + p.opts.Logger.Info(context.Background(), "provision successful; marking job as complete", + slog.F("resource_count", len(msgType.Complete.Resources)), + slog.F("resources", msgType.Complete.Resources), + slog.F("state_length", len(msgType.Complete.State)), + ) + + // Complete job may need to be async if we disconnected... + // When we reconnect we can flush any of these cached values. + _, err = p.client.CompleteJob(p.closeContext, &proto.CompletedJob{ + JobId: p.runningJob.JobId, + Type: &proto.CompletedJob_WorkspaceProvision_{ + WorkspaceProvision: &proto.CompletedJob_WorkspaceProvision{ + State: msgType.Complete.State, + Resources: msgType.Complete.Resources, + }, + }, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("complete job: %s", err)) + return + } + // Return so we stop looping! + return + default: + p.cancelActiveJob(fmt.Sprintf("invalid message type %q received from provisioner", + reflect.TypeOf(msg.Type).String())) + return + } + } +} + +func (p *provisionerDaemon) cancelActiveJob(errMsg string) { + p.runningJobMutex.Lock() + defer p.runningJobMutex.Unlock() + if !p.isRunningJob.Load() { + p.opts.Logger.Warn(context.Background(), "skipping job cancel; none running", slog.F("error_message", errMsg)) + return + } + + p.opts.Logger.Info(context.Background(), "canceling running job", + slog.F("error_message", errMsg), + slog.F("job_id", p.runningJob.JobId), + ) + _, err := p.client.CancelJob(p.closeContext, &proto.CancelledJob{ + JobId: p.runningJob.JobId, + Error: fmt.Sprintf("provisioner daemon: %s", errMsg), + }) + if err != nil { + p.opts.Logger.Error(context.Background(), "couldn't cancel job", slog.Error(err)) + } + p.opts.Logger.Debug(context.Background(), "canceled running job") + p.runningJobContextCancel() + p.isRunningJob.Store(false) +} + +// isClosed returns whether the API is closed or not. +func (p *provisionerDaemon) isClosed() bool { + select { + case <-p.closed: + return true + default: + return false + } +} + +// Close ends the provisioner. It will mark any running jobs as canceled. 
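+// Close is safe to call more than once; later calls return the error recorded by the first close.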
+func (p *provisionerDaemon) Close() error { + return p.closeWithError(nil) +} + +// closeWithError closes the provisioner; subsequent reads/writes will return the error err. +func (p *provisionerDaemon) closeWithError(err error) error { + p.closeMutex.Lock() + defer p.closeMutex.Unlock() + if p.isClosed() { + return p.closeError + } + + if p.isRunningJob.Load() { + errMsg := "provisioner daemon was shutdown gracefully" + if err != nil { + errMsg = err.Error() + } + p.cancelActiveJob(errMsg) + } + + p.opts.Logger.Debug(context.Background(), "closing server with error", slog.Error(err)) + p.closeError = err + close(p.closed) + p.closeContextCancel() + + if p.updateStream != nil { + _ = p.client.DRPCConn().Close() + _ = p.updateStream.Close() + } + + return err +} diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go new file mode 100644 index 0000000000000..7b2e621ff4e33 --- /dev/null +++ b/provisionerd/provisionerd_test.go @@ -0,0 +1,488 @@ +package provisionerd_test + +import ( + "archive/tar" + "bytes" + "context" + "errors" + "io" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/hashicorp/yamux" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + "go.uber.org/goleak" + "storj.io/drpc" + "storj.io/drpc/drpcconn" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/provisionerd" + "github.com/coder/coder/provisionerd/proto" + sdkproto "github.com/coder/coder/provisionersdk/proto" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestProvisionerd(t *testing.T) { + t.Parallel() + + noopUpdateJob := func(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error { + <-stream.Context().Done() + return nil + } + + t.Run("InstantClose", func(t *testing.T) { + t.Parallel() + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{}), nil + }, provisionerd.Provisioners{}) + require.NoError(t, closer.Close()) + }) + + t.Run("ConnectErrorClose", func(t *testing.T) { + t.Parallel() + completeChan := make(chan struct{}) + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + defer close(completeChan) + return nil, errors.New("an error") + }, provisionerd.Provisioners{}) + <-completeChan + require.NoError(t, closer.Close()) + }) + + t.Run("AcquireEmptyJob", func(t *testing.T) { + // The provisioner daemon is supposed to skip the job acquire if + // the job provided is empty. This is to show it successfully + // tried to get a job, but none were available. 
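+ // A second acquire attempt closing completeChan proves the poll loop keeps running after an empty job.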
+ t.Parallel() + completeChan := make(chan struct{}) + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + acquireJobAttempt := 0 + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + if acquireJobAttempt == 1 { + close(completeChan) + } + acquireJobAttempt++ + return &proto.AcquiredJob{}, nil + }, + updateJob: noopUpdateJob, + }), nil + }, provisionerd.Provisioners{}) + <-completeChan + require.NoError(t, closer.Close()) + }) + + t.Run("CloseCancelsJob", func(t *testing.T) { + t.Parallel() + completeChan := make(chan struct{}) + var closer io.Closer + var closerMutex sync.Mutex + closerMutex.Lock() + closer = createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + ProjectSourceArchive: createTar(t, map[string]string{ + "test.txt": "content", + }), + Type: &proto.AcquiredJob_ProjectImport_{ + ProjectImport: &proto.AcquiredJob_ProjectImport{}, + }, + }, nil + }, + updateJob: noopUpdateJob, + cancelJob: func(ctx context.Context, job *proto.CancelledJob) (*proto.Empty, error) { + close(completeChan) + return &proto.Empty{}, nil + }, + }), nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error { + closerMutex.Lock() + defer closerMutex.Unlock() + return closer.Close() + }, + }), + }) + closerMutex.Unlock() + <-completeChan + require.NoError(t, closer.Close()) + }) + + t.Run("MaliciousTar", func(t *testing.T) { + // Ensures tars with "../../../etc/passwd" as the path + // are not allowed to run, and will fail the job. 
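+ // The daemon is expected to cancel the job rather than extract files outside its work directory.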
+ t.Parallel() + completeChan := make(chan struct{}) + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + ProjectSourceArchive: createTar(t, map[string]string{ + "../../../etc/passwd": "content", + }), + Type: &proto.AcquiredJob_ProjectImport_{ + ProjectImport: &proto.AcquiredJob_ProjectImport{}, + }, + }, nil + }, + updateJob: noopUpdateJob, + cancelJob: func(ctx context.Context, job *proto.CancelledJob) (*proto.Empty, error) { + close(completeChan) + return &proto.Empty{}, nil + }, + }), nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{}), + }) + <-completeChan + require.NoError(t, closer.Close()) + }) + + t.Run("ProjectImport", func(t *testing.T) { + t.Parallel() + var ( + didComplete atomic.Bool + didLog atomic.Bool + didAcquireJob atomic.Bool + ) + completeChan := make(chan struct{}) + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + if didAcquireJob.Load() { + close(completeChan) + return &proto.AcquiredJob{}, nil + } + didAcquireJob.Store(true) + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + ProjectSourceArchive: createTar(t, map[string]string{ + "test.txt": "content", + }), + Type: &proto.AcquiredJob_ProjectImport_{ + ProjectImport: &proto.AcquiredJob_ProjectImport{}, + }, + }, nil + }, + updateJob: func(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error { + for { + msg, err := stream.Recv() + if err != nil { + return err + } + if len(msg.ProjectImportLogs) == 0 { + continue + } + + didLog.Store(true) + } + }, + completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { + didComplete.Store(true) + return &proto.Empty{}, nil + }, + }), nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error { + data, err := os.ReadFile(filepath.Join(request.Directory, "test.txt")) + require.NoError(t, err) + require.Equal(t, "content", string(data)) + + err = stream.Send(&sdkproto.Parse_Response{ + Type: &sdkproto.Parse_Response_Log{ + Log: &sdkproto.Log{ + Level: sdkproto.LogLevel_INFO, + Output: "hello", + }, + }, + }) + require.NoError(t, err) + + err = stream.Send(&sdkproto.Parse_Response{ + Type: &sdkproto.Parse_Response_Complete{ + Complete: &sdkproto.Parse_Complete{ + ParameterSchemas: []*sdkproto.ParameterSchema{}, + }, + }, + }) + require.NoError(t, err) + return nil + }, + }), + }) + <-completeChan + require.True(t, didLog.Load()) + require.True(t, didComplete.Load()) + require.NoError(t, closer.Close()) + }) + + t.Run("WorkspaceProvision", func(t *testing.T) { + t.Parallel() + var ( + didComplete atomic.Bool + didLog atomic.Bool + didAcquireJob atomic.Bool + ) + completeChan := make(chan struct{}) + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ 
*proto.Empty) (*proto.AcquiredJob, error) { + if didAcquireJob.Load() { + close(completeChan) + return &proto.AcquiredJob{}, nil + } + didAcquireJob.Store(true) + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + ProjectSourceArchive: createTar(t, map[string]string{ + "test.txt": "content", + }), + Type: &proto.AcquiredJob_WorkspaceProvision_{ + WorkspaceProvision: &proto.AcquiredJob_WorkspaceProvision{}, + }, + }, nil + }, + updateJob: func(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error { + for { + msg, err := stream.Recv() + if err != nil { + return err + } + if len(msg.WorkspaceProvisionLogs) == 0 { + continue + } + + didLog.Store(true) + } + }, + completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { + didComplete.Store(true) + return &proto.Empty{}, nil + }, + }), nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + provision: func(request *sdkproto.Provision_Request, stream sdkproto.DRPCProvisioner_ProvisionStream) error { + err := stream.Send(&sdkproto.Provision_Response{ + Type: &sdkproto.Provision_Response_Log{ + Log: &sdkproto.Log{ + Level: sdkproto.LogLevel_DEBUG, + Output: "wow", + }, + }, + }) + require.NoError(t, err) + + err = stream.Send(&sdkproto.Provision_Response{ + Type: &sdkproto.Provision_Response_Complete{ + Complete: &sdkproto.Provision_Complete{}, + }, + }) + require.NoError(t, err) + return nil + }, + }), + }) + <-completeChan + require.True(t, didLog.Load()) + require.True(t, didComplete.Load()) + require.NoError(t, closer.Close()) + }) +} + +// Creates an in-memory tar of the files provided. +func createTar(t *testing.T, files map[string]string) []byte { + var buffer bytes.Buffer + writer := tar.NewWriter(&buffer) + for path, content := range files { + err := writer.WriteHeader(&tar.Header{ + Name: path, + Size: int64(len(content)), + }) + require.NoError(t, err) + + _, err = writer.Write([]byte(content)) + require.NoError(t, err) + } + + err := writer.Flush() + require.NoError(t, err) + return buffer.Bytes() +} + +// Creates a provisionerd implementation with the provided dialer and provisioners. +func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, provisioners provisionerd.Provisioners) io.Closer { + closer := provisionerd.New(dialer, &provisionerd.Options{ + Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug), + PollInterval: 50 * time.Millisecond, + Provisioners: provisioners, + WorkDirectory: t.TempDir(), + }) + t.Cleanup(func() { + _ = closer.Close() + }) + return closer +} + +// Creates a provisionerd protobuf client that's connected +// to the server implementation provided. +func createProvisionerDaemonClient(t *testing.T, server provisionerDaemonTestServer) proto.DRPCProvisionerDaemonClient { + clientPipe, serverPipe := createTransports(t) + t.Cleanup(func() { + _ = clientPipe.Close() + _ = serverPipe.Close() + }) + mux := drpcmux.New() + err := proto.DRPCRegisterProvisionerDaemon(mux, &server) + require.NoError(t, err) + srv := drpcserver.New(mux) + go func() { + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(cancelFunc) + err := srv.Serve(ctx, serverPipe) + require.NoError(t, err) + }() + return proto.NewDRPCProvisionerDaemonClient(&multiplexedDRPC{clientPipe}) +} + +// Creates a provisioner protobuf client that's connected +// to the server implementation provided. 
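+// The two sides talk over an in-memory yamux pipe from createTransports, so no network is involved.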
+func createProvisionerClient(t *testing.T, server provisionerTestServer) sdkproto.DRPCProvisionerClient { + clientPipe, serverPipe := createTransports(t) + t.Cleanup(func() { + _ = clientPipe.Close() + _ = serverPipe.Close() + }) + mux := drpcmux.New() + err := sdkproto.DRPCRegisterProvisioner(mux, &server) + require.NoError(t, err) + srv := drpcserver.New(mux) + go func() { + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(cancelFunc) + err := srv.Serve(ctx, serverPipe) + require.NoError(t, err) + }() + return sdkproto.NewDRPCProvisionerClient(&multiplexedDRPC{clientPipe}) +} + +type provisionerTestServer struct { + parse func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error + provision func(request *sdkproto.Provision_Request, stream sdkproto.DRPCProvisioner_ProvisionStream) error +} + +func (p *provisionerTestServer) Parse(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error { + return p.parse(request, stream) +} + +func (p *provisionerTestServer) Provision(request *sdkproto.Provision_Request, stream sdkproto.DRPCProvisioner_ProvisionStream) error { + return p.provision(request, stream) +} + +// Fulfills the protobuf interface for a ProvisionerDaemon with +// passable functions for dynamic functionality. +type provisionerDaemonTestServer struct { + acquireJob func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) + updateJob func(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error + cancelJob func(ctx context.Context, job *proto.CancelledJob) (*proto.Empty, error) + completeJob func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) +} + +func (p *provisionerDaemonTestServer) AcquireJob(ctx context.Context, empty *proto.Empty) (*proto.AcquiredJob, error) { + return p.acquireJob(ctx, empty) +} + +func (p *provisionerDaemonTestServer) UpdateJob(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error { + return p.updateJob(stream) +} + +func (p *provisionerDaemonTestServer) CancelJob(ctx context.Context, job *proto.CancelledJob) (*proto.Empty, error) { + return p.cancelJob(ctx, job) +} + +func (p *provisionerDaemonTestServer) CompleteJob(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { + return p.completeJob(ctx, job) +} + +// Creates an in-memory pipe of two yamux sessions. +func createTransports(t *testing.T) (*yamux.Session, *yamux.Session) { + clientReader, clientWriter := io.Pipe() + serverReader, serverWriter := io.Pipe() + yamuxConfig := yamux.DefaultConfig() + yamuxConfig.LogOutput = io.Discard + client, err := yamux.Client(&readWriteCloser{ + ReadCloser: clientReader, + Writer: serverWriter, + }, yamuxConfig) + require.NoError(t, err) + + server, err := yamux.Server(&readWriteCloser{ + ReadCloser: serverReader, + Writer: clientWriter, + }, yamuxConfig) + require.NoError(t, err) + t.Cleanup(func() { + _ = clientReader.Close() + _ = clientWriter.Close() + _ = serverReader.Close() + _ = serverWriter.Close() + _ = client.Close() + _ = server.Close() + }) + return client, server +} + +type readWriteCloser struct { + io.ReadCloser + io.Writer +} + +// Allows concurrent requests on a single dRPC connection. +// Required for calling functions concurrently. 
+type multiplexedDRPC struct { + session *yamux.Session +} + +func (m *multiplexedDRPC) Close() error { + return m.session.Close() +} + +func (m *multiplexedDRPC) Closed() <-chan struct{} { + return m.session.CloseChan() +} + +func (m *multiplexedDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error { + conn, err := m.session.Open() + if err != nil { + return err + } + return drpcconn.New(conn).Invoke(ctx, rpc, enc, in, out) +} + +func (m *multiplexedDRPC) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) { + conn, err := m.session.Open() + if err != nil { + return nil, err + } + return drpcconn.New(conn).NewStream(ctx, rpc, enc) +} diff --git a/provisionersdk/serve.go b/provisionersdk/serve.go index 9b7952001f96c..c278715a145e6 100644 --- a/provisionersdk/serve.go +++ b/provisionersdk/serve.go @@ -4,19 +4,22 @@ import ( "context" "errors" "io" + "net" + "os" "golang.org/x/xerrors" - "storj.io/drpc" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" + "github.com/hashicorp/yamux" + "github.com/coder/coder/provisionersdk/proto" ) // ServeOptions are configurations to serve a provisioner. type ServeOptions struct { - // Transport specifies a custom transport to serve the dRPC connection. - Transport drpc.Transport + // Listener specifies a custom transport to serve the dRPC connection. + Listener net.Listener } // Serve starts a dRPC connection for the provisioner and transport provided. @@ -25,8 +28,17 @@ func Serve(ctx context.Context, server proto.DRPCProvisionerServer, options *Ser options = &ServeOptions{} } // Default to using stdio. - if options.Transport == nil { - options.Transport = TransportStdio() + if options.Listener == nil { + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + stdio, err := yamux.Server(readWriteCloser{ + ReadCloser: os.Stdin, + Writer: os.Stdout, + }, config) + if err != nil { + return xerrors.Errorf("create yamux: %w", err) + } + options.Listener = stdio } + // dRPC is a drop-in replacement for gRPC with less generated code, and faster transports. @@ -40,16 +52,12 @@ func Serve(ctx context.Context, server proto.DRPCProvisionerServer, options *Ser // Only serve a single connection on the transport. // Transports are not multiplexed, and provisioners are // short-lived processes that can be executed concurrently. - err = srv.ServeOne(ctx, options.Transport) + err = srv.Serve(ctx, options.Listener) if err != nil { if errors.Is(err, context.Canceled) { return nil } - if errors.Is(err, io.ErrClosedPipe) { - // This may occur if the transport on either end is - // closed before the context. It's fine to return - // nil here, since the server has nothing to - // communicate with.
+ if errors.Is(err, yamux.ErrSessionShutdown) { return nil } return xerrors.Errorf("serve transport: %w", err) diff --git a/provisionersdk/serve_test.go b/provisionersdk/serve_test.go index 08ac393eb8dfc..cf2dd7517df82 100644 --- a/provisionersdk/serve_test.go +++ b/provisionersdk/serve_test.go @@ -6,7 +6,6 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" - "storj.io/drpc/drpcconn" "storj.io/drpc/drpcerr" "github.com/coder/coder/provisionersdk" @@ -29,12 +28,12 @@ func TestProvisionerSDK(t *testing.T) { defer cancelFunc() go func() { err := provisionersdk.Serve(ctx, &proto.DRPCProvisionerUnimplementedServer{}, &provisionersdk.ServeOptions{ - Transport: server, + Listener: server, }) require.NoError(t, err) }() - api := proto.NewDRPCProvisionerClient(drpcconn.New(client)) + api := proto.NewDRPCProvisionerClient(provisionersdk.Conn(client)) stream, err := api.Parse(context.Background(), &proto.Parse_Request{}) require.NoError(t, err) _, err = stream.Recv() @@ -47,7 +46,7 @@ func TestProvisionerSDK(t *testing.T) { _ = server.Close() err := provisionersdk.Serve(context.Background(), &proto.DRPCProvisionerUnimplementedServer{}, &provisionersdk.ServeOptions{ - Transport: server, + Listener: server, }) require.NoError(t, err) }) diff --git a/provisionersdk/transport.go b/provisionersdk/transport.go index c01a7ab8269e9..7fd87839d174b 100644 --- a/provisionersdk/transport.go +++ b/provisionersdk/transport.go @@ -1,44 +1,74 @@ package provisionersdk import ( + "context" "io" - "os" + "github.com/hashicorp/yamux" "storj.io/drpc" + "storj.io/drpc/drpcconn" ) -// Transport creates a dRPC transport using stdin and stdout. -func TransportStdio() drpc.Transport { - return &transport{ - in: os.Stdin, - out: os.Stdout, +// TransportPipe creates an in-memory pipe for dRPC transport. +func TransportPipe() (*yamux.Session, *yamux.Session) { + clientReader, clientWriter := io.Pipe() + serverReader, serverWriter := io.Pipe() + yamuxConfig := yamux.DefaultConfig() + yamuxConfig.LogOutput = io.Discard + client, err := yamux.Client(&readWriteCloser{ + ReadCloser: clientReader, + Writer: serverWriter, + }, yamuxConfig) + if err != nil { + panic(err) + } + + server, err := yamux.Server(&readWriteCloser{ + ReadCloser: serverReader, + Writer: clientWriter, + }, yamuxConfig) + if err != nil { + panic(err) } + return client, server } -// TransportPipe creates an in-memory pipe for dRPC transport. -func TransportPipe() (drpc.Transport, drpc.Transport) { - clientReader, serverWriter := io.Pipe() - serverReader, clientWriter := io.Pipe() - clientTransport := &transport{clientReader, clientWriter} - serverTransport := &transport{serverReader, serverWriter} +// Conn returns a multiplexed dRPC connection from a yamux session. +func Conn(session *yamux.Session) drpc.Conn { + return &multiplexedDRPC{session} +} - return clientTransport, serverTransport +type readWriteCloser struct { + io.ReadCloser + io.Writer } -// transport wraps an input and output to pipe data. -type transport struct { - in io.ReadCloser - out io.Writer +// Allows concurrent requests on a single dRPC connection. +// Required for calling functions concurrently. 
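+// Each Invoke and NewStream opens a fresh yamux stream and wraps it in its own drpcconn.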
+type multiplexedDRPC struct { + session *yamux.Session } -func (s *transport) Read(data []byte) (int, error) { - return s.in.Read(data) +func (m *multiplexedDRPC) Close() error { + return m.session.Close() } -func (s *transport) Write(data []byte) (int, error) { - return s.out.Write(data) +func (m *multiplexedDRPC) Closed() <-chan struct{} { + return m.session.CloseChan() } -func (s *transport) Close() error { - return s.in.Close() +func (m *multiplexedDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error { + conn, err := m.session.Open() + if err != nil { + return err + } + return drpcconn.New(conn).Invoke(ctx, rpc, enc, in, out) +} + +func (m *multiplexedDRPC) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) { + conn, err := m.session.Open() + if err != nil { + return nil, err + } + return drpcconn.New(conn).NewStream(ctx, rpc, enc) }
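For reference, a minimal sketch of how the transport helpers added above fit together: serve a provisioner over one end of the in-memory yamux pipe and drive it from the other end through the multiplexed dRPC client. The helper name dialInMemoryProvisioner is illustrative only; the provisionersdk and proto calls are the ones introduced in this diff, and error handling is trimmed for brevity.

package example

import (
	"context"

	"github.com/coder/coder/provisionersdk"
	"github.com/coder/coder/provisionersdk/proto"
)

// dialInMemoryProvisioner wires a provisioner server and client together
// without a network: one yamux session per side of an io.Pipe.
func dialInMemoryProvisioner(ctx context.Context) proto.DRPCProvisionerClient {
	clientSession, serverSession := provisionersdk.TransportPipe()

	// Serve a stub provisioner on the server side of the pipe. Serve returns
	// when ctx is canceled or the session shuts down.
	go func() {
		_ = provisionersdk.Serve(ctx, &proto.DRPCProvisionerUnimplementedServer{}, &provisionersdk.ServeOptions{
			Listener: serverSession,
		})
	}()

	// Conn multiplexes dRPC calls over the client session, so Parse and
	// Provision can be invoked concurrently on one transport.
	return proto.NewDRPCProvisionerClient(provisionersdk.Conn(clientSession))
}

A real caller would serve an actual provisioner implementation (for example the Terraform one in this repository) over the same listener, and close both sessions when done.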