refactor: Generalize log ownership to allow for scratch jobs #182

Merged: 5 commits, merged on Feb 7, 2022
Changes from 1 commit
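
This commit replaces the per-owner log plumbing (workspace-history logs and project-version logs) with a single set of provisioner-job logs: one table, one API route, and one pubsub channel keyed by job ID, so logs can also belong to jobs with no workspace or project attached ("scratch jobs"). The new provisionerJobLogsChannel helper is referenced throughout the diff, but its definition is not part of the hunks shown here; the sketch below is a guess at its shape, modeled on the removed workspaceHistoryLogsChannel and projectVersionLogsChannel helpers, and the channel prefix is an assumption.

```go
// Hypothetical sketch of the unified channel helper; only its call sites appear
// in this diff, so the exact format string is assumed, mirroring the removed
// workspace-history and project-version helpers.
package coderd

import (
	"fmt"

	"github.com/google/uuid"
)

// provisionerJobLogsChannel keys pubsub messages by provisioner job ID instead
// of by the job's owner, so any job type can stream its logs the same way.
func provisionerJobLogsChannel(jobID uuid.UUID) string {
	return fmt.Sprintf("provisioner-job-logs:%s", jobID)
}
```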
13 changes: 9 additions & 4 deletions coderd/coderd.go
@@ -96,16 +96,21 @@ func New(options *Options) http.Handler {
r.Route("/{workspacehistory}", func(r chi.Router) {
r.Use(httpmw.ExtractWorkspaceHistoryParam(options.Database))
r.Get("/", api.workspaceHistoryByName)
r.Get("/logs", api.workspaceHistoryLogsByName)
})
})
})
})
})

r.Route("/provisioners/daemons", func(r chi.Router) {
r.Get("/", api.provisionerDaemons)
r.Get("/serve", api.provisionerDaemonsServe)
r.Route("/provisioners", func(r chi.Router) {
r.Route("/daemons", func(r chi.Router) {
r.Get("/", api.provisionerDaemons)
r.Get("/serve", api.provisionerDaemonsServe)
})
r.Route("/jobs/{provisionerjob}", func(r chi.Router) {
r.Use(httpmw.ExtractProvisionerJobParam(options.Database))
r.Get("/logs", api.provisionerJobLogsByID)
})
})
})
r.NotFound(site.Handler().ServeHTTP)
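
With this nesting, the provisioner daemon endpoints keep their paths and job logs move under /provisioners/jobs/{provisionerjob}/logs, one route serving both project imports and workspace provisions. As a rough illustration (the base URL, job ID, and authentication are placeholders, and the router's mount point is not shown in this diff), a one-shot fetch of the logs recorded so far might look like the sketch below; the follow-streaming half of the pattern is sketched later, after the provisionerjoblogs.go handler.

```go
// Hypothetical one-shot fetch against the new job logs route. baseURL and jobID
// are placeholders; without "follow" the handler responds with a JSON array.
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

func fetchJobLogs(baseURL, jobID string) ([]map[string]interface{}, error) {
	url := fmt.Sprintf("%s/provisioners/jobs/%s/logs?before=%d",
		baseURL, jobID, time.Now().UnixMilli())
	resp, err := http.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// Decode loosely: the non-follow path writes the database rows directly,
	// and their exact JSON field names are not visible in this diff.
	var logs []map[string]interface{}
	if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil {
		return nil, err
	}
	return logs, nil
}
```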
5 changes: 0 additions & 5 deletions coderd/projectversion.go
@@ -153,7 +153,6 @@ func (api *api) postProjectVersionByOrganization(rw http.ResponseWriter, r *http
InitiatorID: apiKey.UserID,
Provisioner: project.Provisioner,
Type: database.ProvisionerJobTypeProjectImport,
ProjectID: project.ID,
Input: input,
})
if err != nil {
@@ -249,7 +248,3 @@ func convertProjectParameter(parameter database.ProjectParameter) ProjectParamet
ValidationValueType: parameter.ValidationValueType,
}
}

func projectVersionLogsChannel(projectVersionID uuid.UUID) string {
return fmt.Sprintf("project-version-logs:%s", projectVersionID)
}
130 changes: 39 additions & 91 deletions coderd/provisionerdaemons.go
@@ -165,26 +165,16 @@ func (server *provisionerdServer) AcquireJob(ctx context.Context, _ *proto.Empty
return xerrors.Errorf("request job was invalidated: %s", errorMessage)
}

project, err := server.Database.GetProjectByID(ctx, job.ProjectID)
if err != nil {
return nil, failJob(fmt.Sprintf("get project: %s", err))
}
organization, err := server.Database.GetOrganizationByID(ctx, project.OrganizationID)
if err != nil {
return nil, failJob(fmt.Sprintf("get organization: %s", err))
}
user, err := server.Database.GetUserByID(ctx, job.InitiatorID)
if err != nil {
return nil, failJob(fmt.Sprintf("get user: %s", err))
}

protoJob := &proto.AcquiredJob{
JobId: job.ID.String(),
CreatedAt: job.CreatedAt.UnixMilli(),
Provisioner: string(job.Provisioner),
OrganizationName: organization.Name,
ProjectName: project.Name,
UserName: user.Username,
JobId: job.ID.String(),
CreatedAt: job.CreatedAt.UnixMilli(),
Provisioner: string(job.Provisioner),
UserName: user.Username,
}
var projectVersion database.ProjectVersion
switch job.Type {
@@ -206,6 +196,14 @@ if err != nil {
if err != nil {
return nil, failJob(fmt.Sprintf("get project version: %s", err))
}
project, err := server.Database.GetProjectByID(ctx, projectVersion.ProjectID)
if err != nil {
return nil, failJob(fmt.Sprintf("get project: %s", err))
}
organization, err := server.Database.GetOrganizationByID(ctx, project.OrganizationID)
if err != nil {
return nil, failJob(fmt.Sprintf("get organization: %s", err))
}

// Compute parameters for the workspace to consume.
parameters, err := projectparameter.Compute(ctx, server.Database, projectparameter.Scope{
@@ -246,8 +244,8 @@ func (server *provisionerdServer) AcquireJob(ctx context.Context, _ *proto.Empty

protoJob.Type = &proto.AcquiredJob_ProjectImport_{
ProjectImport: &proto.AcquiredJob_ProjectImport{
ProjectVersionId: projectVersion.ID.String(),
ProjectVersionName: projectVersion.Name,
// This will be replaced once the project import has been refactored.
ProjectName: "placeholder",
Review comment on lines +247 to +248
Contributor: Thanks for explaining this 😄
},
}
}
@@ -289,85 +287,35 @@ func (server *provisionerdServer) UpdateJob(stream proto.DRPCProvisionerDaemon_U
if err != nil {
return xerrors.Errorf("update job: %w", err)
}
switch job.Type {
case database.ProvisionerJobTypeProjectImport:
if len(update.ProjectImportLogs) == 0 {
continue
}
var input projectImportJob
err = json.Unmarshal(job.Input, &input)
if err != nil {
return xerrors.Errorf("unmarshal job input %q: %s", job.Input, err)
}
insertParams := database.InsertProjectVersionLogsParams{
ProjectVersionID: input.ProjectVersionID,
}
for _, log := range update.ProjectImportLogs {
logLevel, err := convertLogLevel(log.Level)
if err != nil {
return xerrors.Errorf("convert log level: %w", err)
}
logSource, err := convertLogSource(log.Source)
if err != nil {
return xerrors.Errorf("convert log source: %w", err)
}
insertParams.ID = append(insertParams.ID, uuid.New())
insertParams.CreatedAt = append(insertParams.CreatedAt, time.UnixMilli(log.CreatedAt))
insertParams.Level = append(insertParams.Level, logLevel)
insertParams.Source = append(insertParams.Source, logSource)
insertParams.Output = append(insertParams.Output, log.Output)
}
logs, err := server.Database.InsertProjectVersionLogs(stream.Context(), insertParams)
if err != nil {
return xerrors.Errorf("insert project logs: %w", err)
}
data, err := json.Marshal(logs)
if err != nil {
return xerrors.Errorf("marshal project log: %w", err)
}
err = server.Pubsub.Publish(projectVersionLogsChannel(input.ProjectVersionID), data)
if err != nil {
return xerrors.Errorf("publish history log: %w", err)
}
case database.ProvisionerJobTypeWorkspaceProvision:
if len(update.WorkspaceProvisionLogs) == 0 {
continue
}
var input workspaceProvisionJob
err = json.Unmarshal(job.Input, &input)
if err != nil {
return xerrors.Errorf("unmarshal job input %q: %s", job.Input, err)
}
insertParams := database.InsertWorkspaceHistoryLogsParams{
WorkspaceHistoryID: input.WorkspaceHistoryID,
}
for _, log := range update.WorkspaceProvisionLogs {
logLevel, err := convertLogLevel(log.Level)
if err != nil {
return xerrors.Errorf("convert log level: %w", err)
}
logSource, err := convertLogSource(log.Source)
if err != nil {
return xerrors.Errorf("convert log source: %w", err)
}
insertParams.ID = append(insertParams.ID, uuid.New())
insertParams.CreatedAt = append(insertParams.CreatedAt, time.UnixMilli(log.CreatedAt))
insertParams.Level = append(insertParams.Level, logLevel)
insertParams.Source = append(insertParams.Source, logSource)
insertParams.Output = append(insertParams.Output, log.Output)
}
logs, err := server.Database.InsertWorkspaceHistoryLogs(stream.Context(), insertParams)
if err != nil {
return xerrors.Errorf("insert workspace logs: %w", err)
}
data, err := json.Marshal(logs)
insertParams := database.InsertProvisionerJobLogsParams{
JobID: parsedID,
}
for _, log := range update.Logs {
logLevel, err := convertLogLevel(log.Level)
if err != nil {
return xerrors.Errorf("marshal project log: %w", err)
return xerrors.Errorf("convert log level: %w", err)
}
err = server.Pubsub.Publish(workspaceHistoryLogsChannel(input.WorkspaceHistoryID), data)
logSource, err := convertLogSource(log.Source)
if err != nil {
return xerrors.Errorf("publish history log: %w", err)
return xerrors.Errorf("convert log source: %w", err)
}
insertParams.ID = append(insertParams.ID, uuid.New())
insertParams.CreatedAt = append(insertParams.CreatedAt, time.UnixMilli(log.CreatedAt))
insertParams.Level = append(insertParams.Level, logLevel)
insertParams.Source = append(insertParams.Source, logSource)
insertParams.Output = append(insertParams.Output, log.Output)
}
logs, err := server.Database.InsertProvisionerJobLogs(stream.Context(), insertParams)
if err != nil {
return xerrors.Errorf("insert job logs: %w", err)
}
data, err := json.Marshal(logs)
if err != nil {
return xerrors.Errorf("marshal job log: %w", err)
}
err = server.Pubsub.Publish(provisionerJobLogsChannel(parsedID), data)
if err != nil {
return xerrors.Errorf("publish job log: %w", err)
}
}
}
75 changes: 35 additions & 40 deletions coderd/workspacehistorylogs.go → coderd/provisionerjoblogs.go
@@ -14,29 +14,28 @@ import (
"github.com/google/uuid"

"cdr.dev/slog"

"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
"github.com/coder/coder/httpmw"
)

// WorkspaceHistoryLog represents a single log from workspace history.
type WorkspaceHistoryLog struct {
// ProvisionerJobLog represents a single log from a provisioner job.
type ProvisionerJobLog struct {
ID uuid.UUID
CreatedAt time.Time `json:"created_at"`
Source database.LogSource `json:"log_source"`
Level database.LogLevel `json:"log_level"`
Output string `json:"output"`
}

// Returns workspace history logs based on query parameters.
// Returns provisioner logs based on query parameters.
// The intended usage for a client to stream all logs (with JS API):
// const timestamp = new Date().getTime();
// 1. GET /logs?before=<timestamp>
// 2. GET /logs?after=<timestamp>&follow
// The combination of these responses should provide all current logs
// to the consumer, and future logs are streamed in the follow request.
func (api *api) workspaceHistoryLogsByName(rw http.ResponseWriter, r *http.Request) {
func (api *api) provisionerJobLogsByID(rw http.ResponseWriter, r *http.Request) {
follow := r.URL.Query().Has("follow")
afterRaw := r.URL.Query().Get("after")
beforeRaw := r.URL.Query().Get("before")
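
The usage comment above spells out the intended client pattern: snapshot a timestamp, fetch everything logged before it, then follow everything after it. The one-shot half was sketched after the router changes earlier; the sketch below covers the follow half, reading the newline-delimited JSON the handler writes until the stream ends when the job completes. This is a minimal sketch, not a tested client: the base URL and job ID are placeholders, and the JSON tags mirror the ProvisionerJobLog type introduced in this diff.

```go
// Hypothetical follow-streaming client for the pattern described in the handler
// comment. baseURL and jobID are placeholders; the JSON tags below mirror the
// ProvisionerJobLog type introduced in this diff.
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"time"
)

type jobLog struct {
	CreatedAt time.Time `json:"created_at"`
	Source    string    `json:"log_source"`
	Level     string    `json:"log_level"`
	Output    string    `json:"output"`
}

func followJobLogs(baseURL, jobID string, after int64) error {
	url := fmt.Sprintf("%s/provisioners/jobs/%s/logs?after=%d&follow", baseURL, jobID, after)
	resp, err := http.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// Each log is written as one JSON object followed by a newline, so a
	// json.Decoder can consume the stream object by object.
	dec := json.NewDecoder(resp.Body)
	for {
		var log jobLog
		if err := dec.Decode(&log); err != nil {
			if errors.Is(err, io.EOF) {
				return nil // server closed the stream once the job completed
			}
			return err
		}
		fmt.Printf("[%s/%s] %s\n", log.Source, log.Level, log.Output)
	}
}
```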
@@ -78,36 +77,36 @@ func (api *api) workspaceHistoryLogsByName(rw http.ResponseWriter, r *http.Reque
before = database.Now()
}

history := httpmw.WorkspaceHistoryParam(r)
job := httpmw.ProvisionerJobParam(r)
if !follow {
logs, err := api.Database.GetWorkspaceHistoryLogsByIDBetween(r.Context(), database.GetWorkspaceHistoryLogsByIDBetweenParams{
WorkspaceHistoryID: history.ID,
CreatedAfter: after,
CreatedBefore: before,
logs, err := api.Database.GetProvisionerLogsByIDBetween(r.Context(), database.GetProvisionerLogsByIDBetweenParams{
JobID: job.ID,
CreatedAfter: after,
CreatedBefore: before,
})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace history: %s", err),
Message: fmt.Sprintf("get provisioner logs: %s", err),
})
return
}
if logs == nil {
logs = []database.WorkspaceHistoryLog{}
logs = []database.ProvisionerJobLog{}
}
render.Status(r, http.StatusOK)
render.JSON(rw, r, logs)
return
}

bufferedLogs := make(chan database.WorkspaceHistoryLog, 128)
closeSubscribe, err := api.Pubsub.Subscribe(workspaceHistoryLogsChannel(history.ID), func(ctx context.Context, message []byte) {
var logs []database.WorkspaceHistoryLog
bufferedLogs := make(chan database.ProvisionerJobLog, 128)
closeSubscribe, err := api.Pubsub.Subscribe(provisionerJobLogsChannel(job.ID), func(ctx context.Context, message []byte) {
var logs []database.ProvisionerJobLog
err := json.Unmarshal(message, &logs)
if err != nil {
api.Logger.Warn(r.Context(), fmt.Sprintf("invalid workspace log on channel %q: %s", workspaceHistoryLogsChannel(history.ID), err.Error()))
api.Logger.Warn(r.Context(), fmt.Sprintf("invalid provisioner job log on channel %q: %s", provisionerJobLogsChannel(job.ID), err.Error()))
return
}

@@ -117,30 +116,30 @@ func (api *api) workspaceHistoryLogsByName(rw http.ResponseWriter, r *http.Reque
default:
// If this overflows users could miss logs streaming. This can happen
// if a database request takes a long amount of time, and we get a lot of logs.
api.Logger.Warn(r.Context(), "workspace history log overflowing channel")
api.Logger.Warn(r.Context(), "provisioner job log overflowing channel")
}
}
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("subscribe to workspace history logs: %s", err),
Message: fmt.Sprintf("subscribe to provisioner job logs: %s", err),
})
return
}
defer closeSubscribe()

workspaceHistoryLogs, err := api.Database.GetWorkspaceHistoryLogsByIDBetween(r.Context(), database.GetWorkspaceHistoryLogsByIDBetweenParams{
WorkspaceHistoryID: history.ID,
CreatedAfter: after,
CreatedBefore: before,
provisionerJobLogs, err := api.Database.GetProvisionerLogsByIDBetween(r.Context(), database.GetProvisionerLogsByIDBetweenParams{
JobID: job.ID,
CreatedAfter: after,
CreatedBefore: before,
})
if errors.Is(err, sql.ErrNoRows) {
err = nil
workspaceHistoryLogs = []database.WorkspaceHistoryLog{}
provisionerJobLogs = []database.ProvisionerJobLog{}
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprint("get workspace history logs: %w", err),
Message: fmt.Sprint("get provisioner job logs: %w", err),
})
return
}
@@ -154,8 +153,8 @@ func (api *api) workspaceHistoryLogsByName(rw http.ResponseWriter, r *http.Reque
// The Go stdlib JSON encoder appends a newline character after message write.
encoder := json.NewEncoder(rw)

for _, workspaceHistoryLog := range workspaceHistoryLogs {
err = encoder.Encode(convertWorkspaceHistoryLog(workspaceHistoryLog))
for _, provisionerJobLog := range provisionerJobLogs {
err = encoder.Encode(convertProvisionerJobLog(provisionerJobLog))
if err != nil {
return
}
@@ -168,15 +167,15 @@ func (api *api) workspaceHistoryLogsByName(rw http.ResponseWriter, r *http.Reque
case <-r.Context().Done():
return
case log := <-bufferedLogs:
err = encoder.Encode(convertWorkspaceHistoryLog(log))
err = encoder.Encode(convertProvisionerJobLog(log))
if err != nil {
return
}
rw.(http.Flusher).Flush()
case <-ticker.C:
job, err := api.Database.GetProvisionerJobByID(r.Context(), history.ProvisionJobID)
job, err := api.Database.GetProvisionerJobByID(r.Context(), job.ID)
if err != nil {
api.Logger.Warn(r.Context(), "streaming workspace logs; checking if job completed", slog.Error(err), slog.F("job_id", history.ProvisionJobID))
api.Logger.Warn(r.Context(), "streaming job logs; checking if completed", slog.Error(err), slog.F("job_id", job.ID.String()))
continue
}
if convertProvisionerJob(job).Status.Completed() {
@@ -186,16 +185,12 @@ func (api *api) workspaceHistoryLogsByName(rw http.ResponseWriter, r *http.Reque
}
}

func convertWorkspaceHistoryLog(workspaceHistoryLog database.WorkspaceHistoryLog) WorkspaceHistoryLog {
return WorkspaceHistoryLog{
ID: workspaceHistoryLog.ID,
CreatedAt: workspaceHistoryLog.CreatedAt,
Source: workspaceHistoryLog.Source,
Level: workspaceHistoryLog.Level,
Output: workspaceHistoryLog.Output,
func convertProvisionerJobLog(provisionerJobLog database.ProvisionerJobLog) ProvisionerJobLog {
return ProvisionerJobLog{
ID: provisionerJobLog.ID,
CreatedAt: provisionerJobLog.CreatedAt,
Source: provisionerJobLog.Source,
Level: provisionerJobLog.Level,
Output: provisionerJobLog.Output,
}
}

func workspaceHistoryLogsChannel(workspaceHistoryID uuid.UUID) string {
return fmt.Sprintf("workspace-history-logs:%s", workspaceHistoryID)
}