diff --git a/cli/server.go b/cli/server.go index 5a15681d23f35..3e48b2d318b7e 100644 --- a/cli/server.go +++ b/cli/server.go @@ -78,6 +78,7 @@ import ( "github.com/coder/coder/coderd/schedule" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/coderd/tracing" + "github.com/coder/coder/coderd/unhanger" "github.com/coder/coder/coderd/updatecheck" "github.com/coder/coder/coderd/util/slice" "github.com/coder/coder/coderd/workspaceapps" @@ -898,11 +899,17 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. return xerrors.Errorf("notify systemd: %w", err) } - autobuildPoller := time.NewTicker(cfg.AutobuildPollInterval.Value()) - defer autobuildPoller.Stop() - autobuildExecutor := autobuild.NewExecutor(ctx, options.Database, coderAPI.TemplateScheduleStore, logger, autobuildPoller.C) + autobuildTicker := time.NewTicker(cfg.AutobuildPollInterval.Value()) + defer autobuildTicker.Stop() + autobuildExecutor := autobuild.NewExecutor(ctx, options.Database, coderAPI.TemplateScheduleStore, logger, autobuildTicker.C) autobuildExecutor.Run() + hangDetectorTicker := time.NewTicker(cfg.JobHangDetectorInterval.Value()) + defer hangDetectorTicker.Stop() + hangDetector := unhanger.New(ctx, options.Database, options.Pubsub, logger, hangDetectorTicker.C) + hangDetector.Start() + defer hangDetector.Close() + // Currently there is no way to ask the server to shut // itself down, so any exit signal will result in a non-zero // exit of the server. diff --git a/cli/testdata/server-config.yaml.golden b/cli/testdata/server-config.yaml.golden index d407d6ded8778..bfd9ed467bca8 100644 --- a/cli/testdata/server-config.yaml.golden +++ b/cli/testdata/server-config.yaml.golden @@ -148,6 +148,9 @@ networking: # Interval to poll for scheduled workspace builds. # (default: 1m0s, type: duration) autobuildPollInterval: 1m0s +# Interval to poll for hung jobs and automatically terminate them. +# (default: 1m0s, type: duration) +jobHangDetectorInterval: 1m0s introspection: prometheus: # Serve prometheus metrics on the address defined by prometheus address. 
diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index a49247c7abf9c..872fd022878cf 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -7239,6 +7239,9 @@ const docTemplate = `{ "in_memory_database": { "type": "boolean" }, + "job_hang_detector_interval": { + "type": "integer" + }, "logging": { "$ref": "#/definitions/codersdk.LoggingConfig" }, diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 946e6e4455d53..56db90e9f26e8 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -6468,6 +6468,9 @@ "in_memory_database": { "type": "boolean" }, + "job_hang_detector_interval": { + "type": "integer" + }, "logging": { "$ref": "#/definitions/codersdk.LoggingConfig" }, diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index e8a45dd7c77d5..f4bf035311e6a 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -68,6 +68,7 @@ import ( "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/coderd/schedule" "github.com/coder/coder/coderd/telemetry" + "github.com/coder/coder/coderd/unhanger" "github.com/coder/coder/coderd/updatecheck" "github.com/coder/coder/coderd/util/ptr" "github.com/coder/coder/coderd/workspaceapps" @@ -256,6 +257,12 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can ).WithStatsChannel(options.AutobuildStats) lifecycleExecutor.Run() + hangDetectorTicker := time.NewTicker(options.DeploymentValues.JobHangDetectorInterval.Value()) + defer hangDetectorTicker.Stop() + hangDetector := unhanger.New(ctx, options.Database, options.Pubsub, slogtest.Make(t, nil).Named("unhanger.detector"), hangDetectorTicker.C) + hangDetector.Start() + t.Cleanup(hangDetector.Close) + var mutex sync.RWMutex var handler http.Handler srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/coderd/database/db.go b/coderd/database/db.go index bcf4de9a35540..9ad12340705ba 100644 --- a/coderd/database/db.go +++ b/coderd/database/db.go @@ -18,13 +18,6 @@ import ( "golang.org/x/xerrors" ) -// Well-known lock IDs for lock functions in the database. These should not -// change. If locks are deprecated, they should be kept to avoid reusing the -// same ID. -const ( - LockIDDeploymentSetup = iota + 1 -) - // Store contains all queryable database functions. // It extends the generated interface to add transaction support. type Store interface { diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index 7b3b8310eaf3c..50bfbad070909 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -3,7 +3,6 @@ package db2sdk import ( "encoding/json" - "time" "github.com/google/uuid" @@ -81,6 +80,9 @@ func TemplateVersionParameter(param database.TemplateVersionParameter) (codersdk } func ProvisionerJobStatus(provisionerJob database.ProvisionerJob) codersdk.ProvisionerJobStatus { + // The case where jobs are hung is handled by the unhanger package. We can't + // just return Failed here when it's hung because that doesn't reflect in + // the database.
switch { case provisionerJob.CanceledAt.Valid: if !provisionerJob.CompletedAt.Valid { @@ -97,8 +99,6 @@ func ProvisionerJobStatus(provisionerJob database.ProvisionerJob) codersdk.Provi return codersdk.ProvisionerJobSucceeded } return codersdk.ProvisionerJobFailed - case database.Now().Sub(provisionerJob.UpdatedAt) > 30*time.Second: - return codersdk.ProvisionerJobFailed default: return codersdk.ProvisionerJobRunning } diff --git a/coderd/database/db2sdk/db2sdk_test.go b/coderd/database/db2sdk/db2sdk_test.go index cb8a7d28345ae..39020e64e9828 100644 --- a/coderd/database/db2sdk/db2sdk_test.go +++ b/coderd/database/db2sdk/db2sdk_test.go @@ -96,17 +96,6 @@ func TestProvisionerJobStatus(t *testing.T) { }, status: codersdk.ProvisionerJobFailed, }, - { - name: "not_updated", - job: database.ProvisionerJob{ - StartedAt: sql.NullTime{ - Time: database.Now().Add(-time.Minute), - Valid: true, - }, - UpdatedAt: database.Now().Add(-31 * time.Second), - }, - status: codersdk.ProvisionerJobFailed, - }, { name: "updated", job: database.ProvisionerJob{ diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index b67fa507c5772..fec2cc01a5edc 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -176,6 +176,25 @@ var ( Scope: rbac.ScopeAll, }.WithCachedASTValue() + // See unhanger package. + subjectHangDetector = rbac.Subject{ + ID: uuid.Nil.String(), + Roles: rbac.Roles([]rbac.Role{ + { + Name: "hangdetector", + DisplayName: "Hang Detector Daemon", + Site: rbac.Permissions(map[string][]rbac.Action{ + rbac.ResourceSystem.Type: {rbac.WildcardSymbol}, + rbac.ResourceTemplate.Type: {rbac.ActionRead}, + rbac.ResourceWorkspace.Type: {rbac.ActionRead, rbac.ActionUpdate}, + }), + Org: map[string][]rbac.Permission{}, + User: []rbac.Permission{}, + }, + }), + Scope: rbac.ScopeAll, + }.WithCachedASTValue() + subjectSystemRestricted = rbac.Subject{ ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -217,6 +236,12 @@ func AsAutostart(ctx context.Context) context.Context { return context.WithValue(ctx, authContextKey{}, subjectAutostart) } +// AsHangDetector returns a context with an actor that has permissions required +// for unhanger.Detector to function. +func AsHangDetector(ctx context.Context) context.Context { + return context.WithValue(ctx, authContextKey{}, subjectHangDetector) +} + // AsSystemRestricted returns a context with an actor that has permissions // required for various system operations (login, logout, metrics cache). 
func AsSystemRestricted(ctx context.Context) context.Context { @@ -950,6 +975,14 @@ func (q *querier) GetGroupsByOrganizationID(ctx context.Context, organizationID return fetchWithPostFilter(q.auth, q.db.GetGroupsByOrganizationID)(ctx, organizationID) } +// TODO: We need to create a ProvisionerJob resource type +func (q *querier) GetHungProvisionerJobs(ctx context.Context, hungSince time.Time) ([]database.ProvisionerJob, error) { + // if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { + // return nil, err + // } + return q.db.GetHungProvisionerJobs(ctx, hungSince) +} + func (q *querier) GetLastUpdateCheck(ctx context.Context) (string, error) { if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { return "", err diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 1821d1906414f..f43e33b33772b 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -1753,6 +1753,19 @@ func (q *fakeQuerier) GetGroupsByOrganizationID(_ context.Context, organizationI return groups, nil } +func (q *fakeQuerier) GetHungProvisionerJobs(_ context.Context, hungSince time.Time) ([]database.ProvisionerJob, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + hungJobs := []database.ProvisionerJob{} + for _, provisionerJob := range q.provisionerJobs { + if provisionerJob.StartedAt.Valid && !provisionerJob.CompletedAt.Valid && provisionerJob.UpdatedAt.Before(hungSince) { + hungJobs = append(hungJobs, provisionerJob) + } + } + return hungJobs, nil +} + func (q *fakeQuerier) GetLastUpdateCheck(_ context.Context) (string, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -2135,7 +2148,7 @@ func (q *fakeQuerier) GetProvisionerLogsAfterID(_ context.Context, arg database. if jobLog.JobID != arg.JobID { continue } - if arg.CreatedAfter != 0 && jobLog.ID < arg.CreatedAfter { + if jobLog.ID <= arg.CreatedAfter { continue } logs = append(logs, jobLog) diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 02fccaf82e2ac..bf5b4e562182e 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -399,6 +399,13 @@ func (m metricsStore) GetGroupsByOrganizationID(ctx context.Context, organizatio return groups, err } +func (m metricsStore) GetHungProvisionerJobs(ctx context.Context, hungSince time.Time) ([]database.ProvisionerJob, error) { + start := time.Now() + jobs, err := m.s.GetHungProvisionerJobs(ctx, hungSince) + m.queryLatencies.WithLabelValues("GetHungProvisionerJobs").Observe(time.Since(start).Seconds()) + return jobs, err +} + func (m metricsStore) GetLastUpdateCheck(ctx context.Context) (string, error) { start := time.Now() version, err := m.s.GetLastUpdateCheck(ctx) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 69c5dd3f63c77..bd39a4bc315d1 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -701,6 +701,21 @@ func (mr *MockStoreMockRecorder) GetGroupsByOrganizationID(arg0, arg1 interface{ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupsByOrganizationID", reflect.TypeOf((*MockStore)(nil).GetGroupsByOrganizationID), arg0, arg1) } +// GetHungProvisionerJobs mocks base method. 
+func (m *MockStore) GetHungProvisionerJobs(arg0 context.Context, arg1 time.Time) ([]database.ProvisionerJob, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetHungProvisionerJobs", arg0, arg1) + ret0, _ := ret[0].([]database.ProvisionerJob) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetHungProvisionerJobs indicates an expected call of GetHungProvisionerJobs. +func (mr *MockStoreMockRecorder) GetHungProvisionerJobs(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHungProvisionerJobs", reflect.TypeOf((*MockStore)(nil).GetHungProvisionerJobs), arg0, arg1) +} + // GetLastUpdateCheck mocks base method. func (m *MockStore) GetLastUpdateCheck(arg0 context.Context) (string, error) { m.ctrl.T.Helper() diff --git a/coderd/database/lock.go b/coderd/database/lock.go new file mode 100644 index 0000000000000..a17903e4a7b8b --- /dev/null +++ b/coderd/database/lock.go @@ -0,0 +1,19 @@ +package database + +import "hash/fnv" + +// Well-known lock IDs for lock functions in the database. These should not +// change. If locks are deprecated, they should be kept in this list to avoid +// reusing the same ID. +const ( + // Keep the unused iota here so we don't need + 1 every time + lockIDUnused = iota + LockIDDeploymentSetup +) + +// GenLockID generates a unique and consistent lock ID from a given string. +func GenLockID(name string) int64 { + hash := fnv.New64() + _, _ = hash.Write([]byte(name)) + return int64(hash.Sum64()) +} diff --git a/coderd/database/querier.go b/coderd/database/querier.go index e14a73cc80675..afe8c742dfed4 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -16,8 +16,6 @@ type sqlcQuerier interface { // // This must be called from within a transaction. The lock will be automatically // released when the transaction ends. - // - // Use database.LockID() to generate a unique lock ID from a string. AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error // Acquires the lock for a single job that isn't started, completed, // canceled, and that matches an array of provisioner types. @@ -75,6 +73,7 @@ type sqlcQuerier interface { GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error) GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]User, error) GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]Group, error) + GetHungProvisionerJobs(ctx context.Context, updatedAt time.Time) ([]ProvisionerJob, error) GetLastUpdateCheck(ctx context.Context) (string, error) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error) GetLatestWorkspaceBuilds(ctx context.Context) ([]WorkspaceBuild, error) @@ -217,8 +216,6 @@ type sqlcQuerier interface { // // This must be called from within a transaction. The lock will be automatically // released when the transaction ends. - // - // Use database.LockID() to generate a unique lock ID from a string. 
TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error UpdateGitAuthLink(ctx context.Context, arg UpdateGitAuthLinkParams) (GitAuthLink, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index ac81fac224dc2..f6b9a2bc05593 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1527,8 +1527,6 @@ SELECT pg_advisory_xact_lock($1) // // This must be called from within a transaction. The lock will be automatically // released when the transaction ends. -// -// Use database.LockID() to generate a unique lock ID from a string. func (q *sqlQuerier) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error { _, err := q.db.ExecContext(ctx, acquireLock, pgAdvisoryXactLock) return err @@ -1542,8 +1540,6 @@ SELECT pg_try_advisory_xact_lock($1) // // This must be called from within a transaction. The lock will be automatically // released when the transaction ends. -// -// Use database.LockID() to generate a unique lock ID from a string. func (q *sqlQuerier) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) { row := q.db.QueryRowContext(ctx, tryAcquireLock, pgTryAdvisoryXactLock) var pg_try_advisory_xact_lock bool @@ -2201,6 +2197,59 @@ func (q *sqlQuerier) AcquireProvisionerJob(ctx context.Context, arg AcquireProvi return i, err } +const getHungProvisionerJobs = `-- name: GetHungProvisionerJobs :many +SELECT + id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags, error_code, trace_metadata +FROM + provisioner_jobs +WHERE + updated_at < $1 + AND started_at IS NOT NULL + AND completed_at IS NULL +` + +func (q *sqlQuerier) GetHungProvisionerJobs(ctx context.Context, updatedAt time.Time) ([]ProvisionerJob, error) { + rows, err := q.db.QueryContext(ctx, getHungProvisionerJobs, updatedAt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ProvisionerJob + for rows.Next() { + var i ProvisionerJob + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.StartedAt, + &i.CanceledAt, + &i.CompletedAt, + &i.Error, + &i.OrganizationID, + &i.InitiatorID, + &i.Provisioner, + &i.StorageMethod, + &i.Type, + &i.Input, + &i.WorkerID, + &i.FileID, + &i.Tags, + &i.ErrorCode, + &i.TraceMetadata, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getProvisionerJobByID = `-- name: GetProvisionerJobByID :one SELECT id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags, error_code, trace_metadata diff --git a/coderd/database/queries/lock.sql b/coderd/database/queries/lock.sql index 2421c73edae1b..0cf8ee0603690 100644 --- a/coderd/database/queries/lock.sql +++ b/coderd/database/queries/lock.sql @@ -3,8 +3,6 @@ -- -- This must be called from within a transaction. The lock will be automatically -- released when the transaction ends. --- --- Use database.LockID() to generate a unique lock ID from a string. SELECT pg_advisory_xact_lock($1); -- name: TryAcquireLock :one @@ -12,6 +10,4 @@ SELECT pg_advisory_xact_lock($1); -- -- This must be called from within a transaction. 
The lock will be automatically -- released when the transaction ends. --- --- Use database.LockID() to generate a unique lock ID from a string. SELECT pg_try_advisory_xact_lock($1); diff --git a/coderd/database/queries/provisionerjobs.sql b/coderd/database/queries/provisionerjobs.sql index d2619cf5e9fdb..b4c113c888dd4 100644 --- a/coderd/database/queries/provisionerjobs.sql +++ b/coderd/database/queries/provisionerjobs.sql @@ -128,3 +128,13 @@ SET error_code = $5 WHERE id = $1; + +-- name: GetHungProvisionerJobs :many +SELECT + * +FROM + provisioner_jobs +WHERE + updated_at < $1 + AND started_at IS NOT NULL + AND completed_at IS NULL; diff --git a/coderd/unhanger/detector.go b/coderd/unhanger/detector.go new file mode 100644 index 0000000000000..a0b61483ab04c --- /dev/null +++ b/coderd/unhanger/detector.go @@ -0,0 +1,363 @@ +package unhanger + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "math/rand" //#nosec // this is only used for shuffling an array to pick random jobs to unhang + "time" + + "golang.org/x/xerrors" + + "github.com/google/uuid" + + "cdr.dev/slog" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/db2sdk" + "github.com/coder/coder/coderd/database/dbauthz" + "github.com/coder/coder/coderd/database/pubsub" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/provisionersdk" +) + +const ( + // HungJobDuration is the duration of time since the last update to a job + // before it is considered hung. + HungJobDuration = 5 * time.Minute + + // HungJobExitTimeout is the duration of time that provisioners should allow + // for a graceful exit upon cancellation due to failing to send an update to + // a job. + // + // Provisioners should avoid keeping a job "running" for longer than this + // time after failing to send an update to the job. + HungJobExitTimeout = 3 * time.Minute + + // MaxJobsPerRun is the maximum number of hung jobs that the detector will + // terminate in a single run. + MaxJobsPerRun = 10 +) + +// HungJobLogMessages are written to provisioner job logs when a job is hung and +// terminated. +var HungJobLogMessages = []string{ + "", + "====================", + "Coder: Build has been detected as hung for 5 minutes and will be terminated.", + "====================", + "", +} + +// acquireLockError is returned when the detector fails to acquire a lock and +// cancels the current run. +type acquireLockError struct{} + +// Error implements error. +func (acquireLockError) Error() string { + return "lock is held by another client" +} + +// jobInelligibleError is returned when a job is not eligible to be terminated +// anymore. +type jobInelligibleError struct { + Err error +} + +// Error implements error. +func (e jobInelligibleError) Error() string { + return fmt.Sprintf("job is no longer eligible to be terminated: %s", e.Err) +} + +// Detector automatically detects hung provisioner jobs, sends messages into the +// build log and terminates them as failed. +type Detector struct { + ctx context.Context + cancel context.CancelFunc + done chan struct{} + + db database.Store + pubsub pubsub.Pubsub + log slog.Logger + tick <-chan time.Time + stats chan<- Stats +} + +// Stats contains statistics about the last run of the detector. +type Stats struct { + // TerminatedJobIDs contains the IDs of all jobs that were detected as hung and + // terminated. + TerminatedJobIDs []uuid.UUID + // Error is the fatal error that occurred during the last run of the + // detector, if any. 
Error may be set to acquireLockError if the detector + // failed to acquire a lock. + Error error +} + +// New returns a new hang detector. +func New(ctx context.Context, db database.Store, pub pubsub.Pubsub, log slog.Logger, tick <-chan time.Time) *Detector { + //nolint:gocritic // Hang detector has a limited set of permissions. + ctx, cancel := context.WithCancel(dbauthz.AsHangDetector(ctx)) + d := &Detector{ + ctx: ctx, + cancel: cancel, + done: make(chan struct{}), + db: db, + pubsub: pub, + log: log, + tick: tick, + stats: nil, + } + return d +} + +// WithStatsChannel will cause the Detector to push a Stats to ch after +// every tick. This push is blocking, so if ch is not read, the detector will +// hang. This should only be used in tests. +func (d *Detector) WithStatsChannel(ch chan<- Stats) *Detector { + d.stats = ch + return d +} + +// Start will cause the detector to detect and unhang provisioner jobs on every +// tick from its channel. It will stop when its context is Done, or when its +// channel is closed. +// +// Start should only be called once. +func (d *Detector) Start() { + go func() { + defer close(d.done) + defer d.cancel() + + for { + select { + case <-d.ctx.Done(): + return + case t, ok := <-d.tick: + if !ok { + return + } + stats := d.run(t) + if stats.Error != nil && !xerrors.As(stats.Error, &acquireLockError{}) { + d.log.Warn(d.ctx, "error running workspace build hang detector once", slog.Error(stats.Error)) + } + if len(stats.TerminatedJobIDs) != 0 { + d.log.Warn(d.ctx, "detected (and terminated) hung provisioner jobs", slog.F("job_ids", stats.TerminatedJobIDs)) + } + if d.stats != nil { + select { + case <-d.ctx.Done(): + return + case d.stats <- stats: + } + } + } + } + }() +} + +// Wait will block until the detector is stopped. +func (d *Detector) Wait() { + <-d.done +} + +// Close will stop the detector. +func (d *Detector) Close() { + d.cancel() + <-d.done +} + +func (d *Detector) run(t time.Time) Stats { + ctx, cancel := context.WithTimeout(d.ctx, 5*time.Minute) + defer cancel() + + stats := Stats{ + TerminatedJobIDs: []uuid.UUID{}, + Error: nil, + } + + // Find all provisioner jobs that are currently running but have not + // received an update in the last 5 minutes. + jobs, err := d.db.GetHungProvisionerJobs(ctx, t.Add(-HungJobDuration)) + if err != nil { + stats.Error = xerrors.Errorf("get hung provisioner jobs: %w", err) + return stats + } + + // Limit the number of jobs we'll unhang in a single run to avoid + // timing out. + if len(jobs) > MaxJobsPerRun { + // Pick a random subset of the jobs to unhang. + rand.Shuffle(len(jobs), func(i, j int) { + jobs[i], jobs[j] = jobs[j], jobs[i] + }) + jobs = jobs[:MaxJobsPerRun] + } + + // Send a message into the build log for each hung job saying that it + // has been detected and will be terminated, then mark the job as + // failed.
+ for _, job := range jobs { + log := d.log.With(slog.F("job_id", job.ID)) + + err := unhangJob(ctx, log, d.db, d.pubsub, job.ID) + if err != nil && !(xerrors.As(err, &acquireLockError{}) || xerrors.As(err, &jobInelligibleError{})) { + log.Error(ctx, "error forcefully terminating hung provisioner job", slog.Error(err)) + continue + } + + stats.TerminatedJobIDs = append(stats.TerminatedJobIDs, job.ID) + } + + return stats +} + +func unhangJob(ctx context.Context, log slog.Logger, db database.Store, pub pubsub.Pubsub, jobID uuid.UUID) error { + var lowestLogID int64 + + err := db.InTx(func(db database.Store) error { + locked, err := db.TryAcquireLock(ctx, database.GenLockID(fmt.Sprintf("hang-detector:%s", jobID))) + if err != nil { + return xerrors.Errorf("acquire lock: %w", err) + } + if !locked { + // This error is ignored. + return acquireLockError{} + } + + // Refetch the job while we hold the lock. + job, err := db.GetProvisionerJobByID(ctx, jobID) + if err != nil { + return xerrors.Errorf("get provisioner job: %w", err) + } + + // Check if we should still unhang it. + jobStatus := db2sdk.ProvisionerJobStatus(job) + if jobStatus != codersdk.ProvisionerJobRunning { + return jobInelligibleError{ + Err: xerrors.Errorf("job is not running (status %s)", jobStatus), + } + } + if job.UpdatedAt.After(time.Now().Add(-HungJobDuration)) { + return jobInelligibleError{ + Err: xerrors.New("job has been updated recently"), + } + } + + log.Info(ctx, "detected hung (>5m) provisioner job, forcefully terminating") + + // First, get the latest logs from the build so we can make sure + // our messages are in the latest stage. + logs, err := db.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{ + JobID: job.ID, + CreatedAfter: 0, + }) + if err != nil { + return xerrors.Errorf("get logs for hung job: %w", err) + } + logStage := "" + if len(logs) != 0 { + logStage = logs[len(logs)-1].Stage + } + if logStage == "" { + logStage = "Unknown" + } + + // Insert the messages into the build log. + insertParams := database.InsertProvisionerJobLogsParams{ + JobID: job.ID, + } + now := database.Now() + for i, msg := range HungJobLogMessages { + // Set the created at in a way that ensures each message has + // a unique timestamp so they will be sorted correctly. + insertParams.CreatedAt = append(insertParams.CreatedAt, now.Add(time.Millisecond*time.Duration(i))) + insertParams.Level = append(insertParams.Level, database.LogLevelError) + insertParams.Stage = append(insertParams.Stage, logStage) + insertParams.Source = append(insertParams.Source, database.LogSourceProvisionerDaemon) + insertParams.Output = append(insertParams.Output, msg) + } + newLogs, err := db.InsertProvisionerJobLogs(ctx, insertParams) + if err != nil { + return xerrors.Errorf("insert logs for hung job: %w", err) + } + lowestLogID = newLogs[0].ID + + // Mark the job as failed. + now = database.Now() + err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + ID: job.ID, + UpdatedAt: now, + CompletedAt: sql.NullTime{ + Time: now, + Valid: true, + }, + Error: sql.NullString{ + String: "Coder: Build has been detected as hung for 5 minutes and has been terminated by hang detector.", + Valid: true, + }, + ErrorCode: sql.NullString{ + Valid: false, + }, + }) + if err != nil { + return xerrors.Errorf("mark job as failed: %w", err) + } + + // If the provisioner job is a workspace build, copy the + // provisioner state from the previous build to this workspace + // build. 
+ if job.Type == database.ProvisionerJobTypeWorkspaceBuild { + build, err := db.GetWorkspaceBuildByJobID(ctx, job.ID) + if err != nil { + return xerrors.Errorf("get workspace build for workspace build job by job id: %w", err) + } + + // Only copy the provisioner state if there's no state in + // the current build. + if len(build.ProvisionerState) == 0 { + // Get the previous build if it exists. + prevBuild, err := db.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{ + WorkspaceID: build.WorkspaceID, + BuildNumber: build.BuildNumber - 1, + }) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + return xerrors.Errorf("get previous workspace build: %w", err) + } + if err == nil { + _, err = db.UpdateWorkspaceBuildByID(ctx, database.UpdateWorkspaceBuildByIDParams{ + ID: build.ID, + UpdatedAt: database.Now(), + ProvisionerState: prevBuild.ProvisionerState, + Deadline: time.Time{}, + MaxDeadline: time.Time{}, + }) + if err != nil { + return xerrors.Errorf("update workspace build by id: %w", err) + } + } + } + } + + return nil + }, nil) + if err != nil { + return xerrors.Errorf("in tx: %w", err) + } + + // Publish the new log notification to pubsub. Use the lowest log ID + // inserted so the log stream will fetch everything after that point. + data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{ + CreatedAfter: lowestLogID - 1, + EndOfLogs: true, + }) + if err != nil { + return xerrors.Errorf("marshal log notification: %w", err) + } + err = pub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data) + if err != nil { + return xerrors.Errorf("publish log notification: %w", err) + } + + return nil +} diff --git a/coderd/unhanger/detector_test.go b/coderd/unhanger/detector_test.go new file mode 100644 index 0000000000000..4f98a82153024 --- /dev/null +++ b/coderd/unhanger/detector_test.go @@ -0,0 +1,724 @@ +package unhanger_test + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbgen" + "github.com/coder/coder/coderd/database/dbtestutil" + "github.com/coder/coder/coderd/unhanger" + "github.com/coder/coder/provisionersdk" + "github.com/coder/coder/testutil" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestDetectorNoJobs(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitLong) + db, pubsub = dbtestutil.NewDB(t) + log = slogtest.Make(t, nil) + tickCh = make(chan time.Time) + statsCh = make(chan unhanger.Stats) + ) + + detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector.Start() + tickCh <- time.Now() + + stats := <-statsCh + require.NoError(t, stats.Error) + require.Empty(t, stats.TerminatedJobIDs) + + detector.Close() + detector.Wait() +} + +func TestDetectorNoHungJobs(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitLong) + db, pubsub = dbtestutil.NewDB(t) + log = slogtest.Make(t, nil) + tickCh = make(chan time.Time) + statsCh = make(chan unhanger.Stats) + ) + + // Insert some jobs that are running and haven't been updated in a while, + // but not enough to be considered hung. 
+ now := time.Now() + org := dbgen.Organization(t, db, database.Organization{}) + user := dbgen.User(t, db, database.User{}) + file := dbgen.File(t, db, database.File{}) + for i := 0; i < 5; i++ { + dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ + CreatedAt: now.Add(-time.Minute * 5), + UpdatedAt: now.Add(-time.Minute * time.Duration(i)), + StartedAt: sql.NullTime{ + Time: now.Add(-time.Minute * 5), + Valid: true, + }, + OrganizationID: org.ID, + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeWorkspaceBuild, + Input: []byte("{}"), + }) + } + + detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector.Start() + tickCh <- now + + stats := <-statsCh + require.NoError(t, stats.Error) + require.Empty(t, stats.TerminatedJobIDs) + + detector.Close() + detector.Wait() +} + +func TestDetectorHungWorkspaceBuild(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitLong) + db, pubsub = dbtestutil.NewDB(t) + log = slogtest.Make(t, nil) + tickCh = make(chan time.Time) + statsCh = make(chan unhanger.Stats) + ) + + var ( + now = time.Now() + twentyMinAgo = now.Add(-time.Minute * 20) + tenMinAgo = now.Add(-time.Minute * 10) + sixMinAgo = now.Add(-time.Minute * 6) + org = dbgen.Organization(t, db, database.Organization{}) + user = dbgen.User(t, db, database.User{}) + file = dbgen.File(t, db, database.File{}) + template = dbgen.Template(t, db, database.Template{ + OrganizationID: org.ID, + CreatedBy: user.ID, + }) + templateVersion = dbgen.TemplateVersion(t, db, database.TemplateVersion{ + OrganizationID: org.ID, + TemplateID: uuid.NullUUID{ + UUID: template.ID, + Valid: true, + }, + CreatedBy: user.ID, + }) + workspace = dbgen.Workspace(t, db, database.Workspace{ + OwnerID: user.ID, + OrganizationID: org.ID, + TemplateID: template.ID, + }) + + // Previous build. + expectedWorkspaceBuildState = []byte(`{"dean":"cool","colin":"also cool"}`) + previousWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ + CreatedAt: twentyMinAgo, + UpdatedAt: twentyMinAgo, + StartedAt: sql.NullTime{ + Time: twentyMinAgo, + Valid: true, + }, + CompletedAt: sql.NullTime{ + Time: twentyMinAgo, + Valid: true, + }, + OrganizationID: org.ID, + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeWorkspaceBuild, + Input: []byte("{}"), + }) + _ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + WorkspaceID: workspace.ID, + TemplateVersionID: templateVersion.ID, + BuildNumber: 1, + ProvisionerState: expectedWorkspaceBuildState, + JobID: previousWorkspaceBuildJob.ID, + }) + + // Current build. + currentWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ + CreatedAt: tenMinAgo, + UpdatedAt: sixMinAgo, + StartedAt: sql.NullTime{ + Time: tenMinAgo, + Valid: true, + }, + OrganizationID: org.ID, + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeWorkspaceBuild, + Input: []byte("{}"), + }) + currentWorkspaceBuild = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + WorkspaceID: workspace.ID, + TemplateVersionID: templateVersion.ID, + BuildNumber: 2, + JobID: currentWorkspaceBuildJob.ID, + // No provisioner state. 
+ }) + ) + + t.Log("previous job ID: ", previousWorkspaceBuildJob.ID) + t.Log("current job ID: ", currentWorkspaceBuildJob.ID) + + detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector.Start() + tickCh <- now + + stats := <-statsCh + require.NoError(t, stats.Error) + require.Len(t, stats.TerminatedJobIDs, 1) + require.Equal(t, currentWorkspaceBuildJob.ID, stats.TerminatedJobIDs[0]) + + // Check that the current provisioner job was updated. + job, err := db.GetProvisionerJobByID(ctx, currentWorkspaceBuildJob.ID) + require.NoError(t, err) + require.WithinDuration(t, now, job.UpdatedAt, 30*time.Second) + require.True(t, job.CompletedAt.Valid) + require.WithinDuration(t, now, job.CompletedAt.Time, 30*time.Second) + require.True(t, job.Error.Valid) + require.Contains(t, job.Error.String, "Build has been detected as hung") + require.False(t, job.ErrorCode.Valid) + + // Check that the provisioner state was copied. + build, err := db.GetWorkspaceBuildByID(ctx, currentWorkspaceBuild.ID) + require.NoError(t, err) + require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState) + + detector.Close() + detector.Wait() +} + +func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitLong) + db, pubsub = dbtestutil.NewDB(t) + log = slogtest.Make(t, nil) + tickCh = make(chan time.Time) + statsCh = make(chan unhanger.Stats) + ) + + var ( + now = time.Now() + twentyMinAgo = now.Add(-time.Minute * 20) + tenMinAgo = now.Add(-time.Minute * 10) + sixMinAgo = now.Add(-time.Minute * 6) + org = dbgen.Organization(t, db, database.Organization{}) + user = dbgen.User(t, db, database.User{}) + file = dbgen.File(t, db, database.File{}) + template = dbgen.Template(t, db, database.Template{ + OrganizationID: org.ID, + CreatedBy: user.ID, + }) + templateVersion = dbgen.TemplateVersion(t, db, database.TemplateVersion{ + OrganizationID: org.ID, + TemplateID: uuid.NullUUID{ + UUID: template.ID, + Valid: true, + }, + CreatedBy: user.ID, + }) + workspace = dbgen.Workspace(t, db, database.Workspace{ + OwnerID: user.ID, + OrganizationID: org.ID, + TemplateID: template.ID, + }) + + // Previous build. + previousWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ + CreatedAt: twentyMinAgo, + UpdatedAt: twentyMinAgo, + StartedAt: sql.NullTime{ + Time: twentyMinAgo, + Valid: true, + }, + CompletedAt: sql.NullTime{ + Time: twentyMinAgo, + Valid: true, + }, + OrganizationID: org.ID, + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeWorkspaceBuild, + Input: []byte("{}"), + }) + _ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + WorkspaceID: workspace.ID, + TemplateVersionID: templateVersion.ID, + BuildNumber: 1, + ProvisionerState: []byte(`{"dean":"NOT cool","colin":"also NOT cool"}`), + JobID: previousWorkspaceBuildJob.ID, + }) + + // Current build. 
+ expectedWorkspaceBuildState = []byte(`{"dean":"cool","colin":"also cool"}`) + currentWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ + CreatedAt: tenMinAgo, + UpdatedAt: sixMinAgo, + StartedAt: sql.NullTime{ + Time: tenMinAgo, + Valid: true, + }, + OrganizationID: org.ID, + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeWorkspaceBuild, + Input: []byte("{}"), + }) + currentWorkspaceBuild = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + WorkspaceID: workspace.ID, + TemplateVersionID: templateVersion.ID, + BuildNumber: 2, + JobID: currentWorkspaceBuildJob.ID, + // Should not be overridden. + ProvisionerState: expectedWorkspaceBuildState, + }) + ) + + t.Log("previous job ID: ", previousWorkspaceBuildJob.ID) + t.Log("current job ID: ", currentWorkspaceBuildJob.ID) + + detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector.Start() + tickCh <- now + + stats := <-statsCh + require.NoError(t, stats.Error) + require.Len(t, stats.TerminatedJobIDs, 1) + require.Equal(t, currentWorkspaceBuildJob.ID, stats.TerminatedJobIDs[0]) + + // Check that the current provisioner job was updated. + job, err := db.GetProvisionerJobByID(ctx, currentWorkspaceBuildJob.ID) + require.NoError(t, err) + require.WithinDuration(t, now, job.UpdatedAt, 30*time.Second) + require.True(t, job.CompletedAt.Valid) + require.WithinDuration(t, now, job.CompletedAt.Time, 30*time.Second) + require.True(t, job.Error.Valid) + require.Contains(t, job.Error.String, "Build has been detected as hung") + require.False(t, job.ErrorCode.Valid) + + // Check that the provisioner state was NOT copied. + build, err := db.GetWorkspaceBuildByID(ctx, currentWorkspaceBuild.ID) + require.NoError(t, err) + require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState) + + detector.Close() + detector.Wait() +} + +func TestDetectorHungWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitLong) + db, pubsub = dbtestutil.NewDB(t) + log = slogtest.Make(t, nil) + tickCh = make(chan time.Time) + statsCh = make(chan unhanger.Stats) + ) + + var ( + now = time.Now() + tenMinAgo = now.Add(-time.Minute * 10) + sixMinAgo = now.Add(-time.Minute * 6) + org = dbgen.Organization(t, db, database.Organization{}) + user = dbgen.User(t, db, database.User{}) + file = dbgen.File(t, db, database.File{}) + template = dbgen.Template(t, db, database.Template{ + OrganizationID: org.ID, + CreatedBy: user.ID, + }) + templateVersion = dbgen.TemplateVersion(t, db, database.TemplateVersion{ + OrganizationID: org.ID, + TemplateID: uuid.NullUUID{ + UUID: template.ID, + Valid: true, + }, + CreatedBy: user.ID, + }) + workspace = dbgen.Workspace(t, db, database.Workspace{ + OwnerID: user.ID, + OrganizationID: org.ID, + TemplateID: template.ID, + }) + + // First build. 
+ expectedWorkspaceBuildState = []byte(`{"dean":"cool","colin":"also cool"}`) + currentWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ + CreatedAt: tenMinAgo, + UpdatedAt: sixMinAgo, + StartedAt: sql.NullTime{ + Time: tenMinAgo, + Valid: true, + }, + OrganizationID: org.ID, + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeWorkspaceBuild, + Input: []byte("{}"), + }) + currentWorkspaceBuild = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + WorkspaceID: workspace.ID, + TemplateVersionID: templateVersion.ID, + BuildNumber: 1, + JobID: currentWorkspaceBuildJob.ID, + // Should not be overridden. + ProvisionerState: expectedWorkspaceBuildState, + }) + ) + + t.Log("current job ID: ", currentWorkspaceBuildJob.ID) + + detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector.Start() + tickCh <- now + + stats := <-statsCh + require.NoError(t, stats.Error) + require.Len(t, stats.TerminatedJobIDs, 1) + require.Equal(t, currentWorkspaceBuildJob.ID, stats.TerminatedJobIDs[0]) + + // Check that the current provisioner job was updated. + job, err := db.GetProvisionerJobByID(ctx, currentWorkspaceBuildJob.ID) + require.NoError(t, err) + require.WithinDuration(t, now, job.UpdatedAt, 30*time.Second) + require.True(t, job.CompletedAt.Valid) + require.WithinDuration(t, now, job.CompletedAt.Time, 30*time.Second) + require.True(t, job.Error.Valid) + require.Contains(t, job.Error.String, "Build has been detected as hung") + require.False(t, job.ErrorCode.Valid) + + // Check that the provisioner state was NOT updated. + build, err := db.GetWorkspaceBuildByID(ctx, currentWorkspaceBuild.ID) + require.NoError(t, err) + require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState) + + detector.Close() + detector.Wait() +} + +func TestDetectorHungOtherJobTypes(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitLong) + db, pubsub = dbtestutil.NewDB(t) + log = slogtest.Make(t, nil) + tickCh = make(chan time.Time) + statsCh = make(chan unhanger.Stats) + ) + + var ( + now = time.Now() + tenMinAgo = now.Add(-time.Minute * 10) + sixMinAgo = now.Add(-time.Minute * 6) + org = dbgen.Organization(t, db, database.Organization{}) + user = dbgen.User(t, db, database.User{}) + file = dbgen.File(t, db, database.File{}) + + // Template import job. + templateImportJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ + CreatedAt: tenMinAgo, + UpdatedAt: sixMinAgo, + StartedAt: sql.NullTime{ + Time: tenMinAgo, + Valid: true, + }, + OrganizationID: org.ID, + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeTemplateVersionImport, + Input: []byte("{}"), + }) + + // Template dry-run job. 
+ templateDryRunJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ + CreatedAt: tenMinAgo, + UpdatedAt: sixMinAgo, + StartedAt: sql.NullTime{ + Time: tenMinAgo, + Valid: true, + }, + OrganizationID: org.ID, + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeTemplateVersionDryRun, + Input: []byte("{}"), + }) + ) + + t.Log("template import job ID: ", templateImportJob.ID) + t.Log("template dry-run job ID: ", templateDryRunJob.ID) + + detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector.Start() + tickCh <- now + + stats := <-statsCh + require.NoError(t, stats.Error) + require.Len(t, stats.TerminatedJobIDs, 2) + require.Contains(t, stats.TerminatedJobIDs, templateImportJob.ID) + require.Contains(t, stats.TerminatedJobIDs, templateDryRunJob.ID) + + // Check that the template import job was updated. + job, err := db.GetProvisionerJobByID(ctx, templateImportJob.ID) + require.NoError(t, err) + require.WithinDuration(t, now, job.UpdatedAt, 30*time.Second) + require.True(t, job.CompletedAt.Valid) + require.WithinDuration(t, now, job.CompletedAt.Time, 30*time.Second) + require.True(t, job.Error.Valid) + require.Contains(t, job.Error.String, "Build has been detected as hung") + require.False(t, job.ErrorCode.Valid) + + // Check that the template dry-run job was updated. + job, err = db.GetProvisionerJobByID(ctx, templateDryRunJob.ID) + require.NoError(t, err) + require.WithinDuration(t, now, job.UpdatedAt, 30*time.Second) + require.True(t, job.CompletedAt.Valid) + require.WithinDuration(t, now, job.CompletedAt.Time, 30*time.Second) + require.True(t, job.Error.Valid) + require.Contains(t, job.Error.String, "Build has been detected as hung") + require.False(t, job.ErrorCode.Valid) + + detector.Close() + detector.Wait() +} + +func TestDetectorPushesLogs(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + preLogCount int + preLogStage string + expectStage string + }{ + { + name: "WithExistingLogs", + preLogCount: 10, + preLogStage: "Stage Name", + expectStage: "Stage Name", + }, + { + name: "WithExistingLogsNoStage", + preLogCount: 10, + preLogStage: "", + expectStage: "Unknown", + }, + { + name: "WithoutExistingLogs", + preLogCount: 0, + expectStage: "Unknown", + }, + } + + for _, c := range cases { + c := c + + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitLong) + db, pubsub = dbtestutil.NewDB(t) + log = slogtest.Make(t, nil) + tickCh = make(chan time.Time) + statsCh = make(chan unhanger.Stats) + ) + + var ( + now = time.Now() + tenMinAgo = now.Add(-time.Minute * 10) + sixMinAgo = now.Add(-time.Minute * 6) + org = dbgen.Organization(t, db, database.Organization{}) + user = dbgen.User(t, db, database.User{}) + file = dbgen.File(t, db, database.File{}) + + // Template import job. + templateImportJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ + CreatedAt: tenMinAgo, + UpdatedAt: sixMinAgo, + StartedAt: sql.NullTime{ + Time: tenMinAgo, + Valid: true, + }, + OrganizationID: org.ID, + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeTemplateVersionImport, + Input: []byte("{}"), + }) + ) + + t.Log("template import job ID: ", templateImportJob.ID) + + // Insert some logs at the start of the job. 
+ if c.preLogCount > 0 { + insertParams := database.InsertProvisionerJobLogsParams{ + JobID: templateImportJob.ID, + } + for i := 0; i < c.preLogCount; i++ { + insertParams.CreatedAt = append(insertParams.CreatedAt, tenMinAgo.Add(time.Millisecond*time.Duration(i))) + insertParams.Level = append(insertParams.Level, database.LogLevelInfo) + insertParams.Stage = append(insertParams.Stage, c.preLogStage) + insertParams.Source = append(insertParams.Source, database.LogSourceProvisioner) + insertParams.Output = append(insertParams.Output, fmt.Sprintf("Output %d", i)) + } + logs, err := db.InsertProvisionerJobLogs(ctx, insertParams) + require.NoError(t, err) + require.Len(t, logs, 10) + } + + detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector.Start() + + // Create pubsub subscription to listen for new log events. + pubsubCalled := make(chan int64, 1) + pubsubCancel, err := pubsub.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(templateImportJob.ID), func(ctx context.Context, message []byte) { + defer close(pubsubCalled) + var event provisionersdk.ProvisionerJobLogsNotifyMessage + err := json.Unmarshal(message, &event) + if !assert.NoError(t, err) { + return + } + + assert.True(t, event.EndOfLogs) + pubsubCalled <- event.CreatedAfter + }) + require.NoError(t, err) + defer pubsubCancel() + + tickCh <- now + + stats := <-statsCh + require.NoError(t, stats.Error) + require.Len(t, stats.TerminatedJobIDs, 1) + require.Contains(t, stats.TerminatedJobIDs, templateImportJob.ID) + + after := <-pubsubCalled + + // Get the jobs after the given time and check that they are what we + // expect. + logs, err := db.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{ + JobID: templateImportJob.ID, + CreatedAfter: after, + }) + require.NoError(t, err) + require.Len(t, logs, len(unhanger.HungJobLogMessages)) + for i, log := range logs { + assert.Equal(t, database.LogLevelError, log.Level) + assert.Equal(t, c.expectStage, log.Stage) + assert.Equal(t, database.LogSourceProvisionerDaemon, log.Source) + assert.Equal(t, unhanger.HungJobLogMessages[i], log.Output) + } + + // Double check the full log count. + logs, err = db.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{ + JobID: templateImportJob.ID, + CreatedAfter: 0, + }) + require.NoError(t, err) + require.Len(t, logs, c.preLogCount+len(unhanger.HungJobLogMessages)) + + detector.Close() + detector.Wait() + }) + } +} + +func TestDetectorMaxJobsPerRun(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitLong) + db, pubsub = dbtestutil.NewDB(t) + log = slogtest.Make(t, nil) + tickCh = make(chan time.Time) + statsCh = make(chan unhanger.Stats) + org = dbgen.Organization(t, db, database.Organization{}) + user = dbgen.User(t, db, database.User{}) + file = dbgen.File(t, db, database.File{}) + ) + + // Create unhanger.MaxJobsPerRun + 1 hung jobs. 
+ now := time.Now() + for i := 0; i < unhanger.MaxJobsPerRun+1; i++ { + dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ + CreatedAt: now.Add(-time.Hour), + UpdatedAt: now.Add(-time.Hour), + StartedAt: sql.NullTime{ + Time: now.Add(-time.Hour), + Valid: true, + }, + OrganizationID: org.ID, + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeTemplateVersionImport, + Input: []byte("{}"), + }) + } + + detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh) + detector.Start() + tickCh <- now + + // Make sure that only unhanger.MaxJobsPerRun jobs are terminated. + stats := <-statsCh + require.NoError(t, stats.Error) + require.Len(t, stats.TerminatedJobIDs, unhanger.MaxJobsPerRun) + + // Run the detector again and make sure that only the remaining job is + // terminated. + tickCh <- now + stats = <-statsCh + require.NoError(t, stats.Error) + require.Len(t, stats.TerminatedJobIDs, 1) + + detector.Close() + detector.Wait() +} diff --git a/codersdk/deployment.go b/codersdk/deployment.go index 34bc0028da8f4..dc758e5a76242 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -124,6 +124,7 @@ type DeploymentValues struct { // HTTPAddress is a string because it may be set to zero to disable. HTTPAddress clibase.String `json:"http_address,omitempty" typescript:",notnull"` AutobuildPollInterval clibase.Duration `json:"autobuild_poll_interval,omitempty"` + JobHangDetectorInterval clibase.Duration `json:"job_hang_detector_interval,omitempty"` DERP DERP `json:"derp,omitempty" typescript:",notnull"` Prometheus PrometheusConfig `json:"prometheus,omitempty" typescript:",notnull"` Pprof PprofConfig `json:"pprof,omitempty" typescript:",notnull"` @@ -539,6 +540,16 @@ when required by your organization's security policy.`, Value: &c.AutobuildPollInterval, YAML: "autobuildPollInterval", }, + { + Name: "Job Hang Detector Interval", + Description: "Interval to poll for hung jobs and automatically terminate them.", + Flag: "job-hang-detector-interval", + Env: "CODER_JOB_HANG_DETECTOR_INTERVAL", + Hidden: true, + Default: time.Minute.String(), + Value: &c.JobHangDetectorInterval, + YAML: "jobHangDetectorInterval", + }, httpAddress, tlsBindAddress, { diff --git a/docs/api/general.md b/docs/api/general.md index d395652e2a3c5..1655fb9d2fb00 100644 --- a/docs/api/general.md +++ b/docs/api/general.md @@ -214,6 +214,7 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \ }, "http_address": "string", "in_memory_database": true, + "job_hang_detector_interval": 0, "logging": { "human": "string", "json": "string", diff --git a/docs/api/schemas.md b/docs/api/schemas.md index b6083f03e0736..f332d03968fb1 100644 --- a/docs/api/schemas.md +++ b/docs/api/schemas.md @@ -1891,6 +1891,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in }, "http_address": "string", "in_memory_database": true, + "job_hang_detector_interval": 0, "logging": { "human": "string", "json": "string", @@ -2221,6 +2222,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in }, "http_address": "string", "in_memory_database": true, + "job_hang_detector_interval": 0, "logging": { "human": "string", "json": "string", @@ -2400,6 +2402,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in | `git_auth` | [clibase.Struct-array_codersdk_GitAuthConfig](#clibasestruct-array_codersdk_gitauthconfig) | 
false | | | | `http_address` | string | false | | Http address is a string because it may be set to zero to disable. | | `in_memory_database` | boolean | false | | | +| `job_hang_detector_interval` | integer | false | | | | `logging` | [codersdk.LoggingConfig](#codersdkloggingconfig) | false | | | | `max_session_expiry` | integer | false | | | | `max_token_lifetime` | integer | false | | | diff --git a/provisioner/terraform/provision.go b/provisioner/terraform/provision.go index 8716c95c84af7..5d249a30b30f6 100644 --- a/provisioner/terraform/provision.go +++ b/provisioner/terraform/provision.go @@ -49,7 +49,7 @@ func (s *server) Provision(stream proto.DRPCProvisioner_ProvisionStream) error { ctx, cancel := context.WithCancel(ctx) defer cancel() - // Create a separate context for forcefull cancellation not tied to + // Create a separate context for forceful cancellation not tied to // the stream so that we can control when to terminate the process. killCtx, kill := context.WithCancel(context.Background()) defer kill() @@ -57,13 +57,15 @@ func (s *server) Provision(stream proto.DRPCProvisioner_ProvisionStream) error { // Ensure processes are eventually cleaned up on graceful // cancellation or disconnect. go func() { - <-stream.Context().Done() + <-ctx.Done() // TODO(mafredri): We should track this provision request as // part of graceful server shutdown procedure. Waiting on a // process here should delay provisioner/coder shutdown. + t := time.NewTimer(s.exitTimeout) + defer t.Stop() select { - case <-time.After(s.exitTimeout): + case <-t.C: kill() case <-killCtx.Done(): } diff --git a/provisioner/terraform/provision_test.go b/provisioner/terraform/provision_test.go index 7879d30fc76ed..1cdc589d4d539 100644 --- a/provisioner/terraform/provision_test.go +++ b/provisioner/terraform/provision_test.go @@ -129,8 +129,7 @@ func TestProvision_Cancel(t *testing.T) { require.NoError(t, err) ctx, api := setupProvisioner(t, &provisionerServeOptions{ - binaryPath: binPath, - exitTimeout: time.Nanosecond, + binaryPath: binPath, }) response, err := api.Provision(ctx) @@ -186,6 +185,75 @@ func TestProvision_Cancel(t *testing.T) { } } +func TestProvision_CancelTimeout(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("This test uses interrupts and is not supported on Windows") + } + + cwd, err := os.Getwd() + require.NoError(t, err) + fakeBin := filepath.Join(cwd, "testdata", "fake_cancel_hang.sh") + + dir := t.TempDir() + binPath := filepath.Join(dir, "terraform") + + // Example: exec /path/to/terrafork_fake_cancel.sh 1.2.1 apply "$@" + content := fmt.Sprintf("#!/bin/sh\nexec %q %s \"$@\"\n", fakeBin, terraform.TerraformVersion.String()) + err = os.WriteFile(binPath, []byte(content), 0o755) //#nosec + require.NoError(t, err) + + ctx, api := setupProvisioner(t, &provisionerServeOptions{ + binaryPath: binPath, + exitTimeout: time.Second, + }) + + response, err := api.Provision(ctx) + require.NoError(t, err) + err = response.Send(&proto.Provision_Request{ + Type: &proto.Provision_Request_Apply{ + Apply: &proto.Provision_Apply{ + Config: &proto.Provision_Config{ + Directory: dir, + Metadata: &proto.Provision_Metadata{}, + }, + }, + }, + }) + require.NoError(t, err) + + for _, line := range []string{"init", "apply_start"} { + LoopStart: + msg, err := response.Recv() + require.NoError(t, err) + + t.Log(msg.Type) + + log := msg.GetLog() + if log == nil { + goto LoopStart + } + require.Equal(t, line, log.Output) + } + + err = response.Send(&proto.Provision_Request{ + Type: 
&proto.Provision_Request_Cancel{ + Cancel: &proto.Provision_Cancel{}, + }, + }) + require.NoError(t, err) + + for { + msg, err := response.Recv() + require.NoError(t, err) + + if c := msg.GetComplete(); c != nil { + require.Contains(t, c.Error, "killed") + break + } + } +} + func TestProvision(t *testing.T) { t.Parallel() diff --git a/provisioner/terraform/serve.go b/provisioner/terraform/serve.go index 4c3f5e18415db..23f880e6c0418 100644 --- a/provisioner/terraform/serve.go +++ b/provisioner/terraform/serve.go @@ -12,13 +12,10 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" + "github.com/coder/coder/coderd/unhanger" "github.com/coder/coder/provisionersdk" ) -const ( - defaultExitTimeout = 5 * time.Minute -) - type ServeOptions struct { *provisionersdk.ServeOptions @@ -31,14 +28,15 @@ type ServeOptions struct { Tracer trace.Tracer // ExitTimeout defines how long we will wait for a running Terraform - // command to exit (cleanly) if the provision was stopped. This only - // happens when the command is still running after the provision - // stream is closed. If the provision is canceled via RPC, this - // timeout will not be used. + // command to exit (cleanly) if the provision was stopped. This + // happens when the provision is canceled via RPC and when the command is + // still running after the provision stream is closed. // // This is a no-op on Windows where the process can't be interrupted. // - // Default value: 5 minutes. + // Default value: 3 minutes (unhanger.HungJobExitTimeout). This value should + // be kept less than the value that Coder uses to mark hung jobs as failed, + // which is 5 minutes (see unhanger package). ExitTimeout time.Duration } @@ -96,7 +94,7 @@ func Serve(ctx context.Context, options *ServeOptions) error { options.Tracer = trace.NewNoopTracerProvider().Tracer("noop") } if options.ExitTimeout == 0 { - options.ExitTimeout = defaultExitTimeout + options.ExitTimeout = unhanger.HungJobExitTimeout } return provisionersdk.Serve(ctx, &server{ execMut: &sync.Mutex{}, diff --git a/provisioner/terraform/testdata/fake_cancel_hang.sh b/provisioner/terraform/testdata/fake_cancel_hang.sh new file mode 100755 index 0000000000000..c6d29c88c733f --- /dev/null +++ b/provisioner/terraform/testdata/fake_cancel_hang.sh @@ -0,0 +1,41 @@ +#!/bin/sh + +VERSION=$1 +shift 1 + +json_print() { + echo "{\"@level\":\"error\",\"@message\":\"$*\"}" +} + +case "$1" in +version) + cat <<-EOF + { + "terraform_version": "${VERSION}", + "platform": "linux_amd64", + "provider_selections": {}, + "terraform_outdated": false + } + EOF + exit 0 + ;; +init) + echo "init" + exit 0 + ;; +apply) + trap 'json_print interrupt' INT + + json_print apply_start + sleep 10 2>/dev/null >/dev/null + json_print apply_end + + exit 0 + ;; +plan) + echo "plan not supported" + exit 1 + ;; +esac + +exit 0 diff --git a/provisionerd/runner/runner.go b/provisionerd/runner/runner.go index 1db005aab95e7..7ca5b7f2c4e2c 100644 --- a/provisionerd/runner/runner.go +++ b/provisionerd/runner/runner.go @@ -337,6 +337,9 @@ func (r *Runner) sendHeartbeat(ctx context.Context) (*proto.UpdateJobResponse, e } func (r *Runner) update(ctx context.Context, u *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + ctx, span := r.startTrace(ctx, tracing.FuncName()) defer span.End() defer func() { @@ -537,6 +540,7 @@ func (r *Runner) heartbeatRoutine(ctx context.Context) { resp, err := r.sendHeartbeat(ctx) if err != nil { + // Calling Fail starts cancellation 
so the process will exit. err = r.Fail(ctx, r.failedJobf("send periodic update: %s", err)) if err != nil { r.logger.Error(ctx, "failed to call FailJob", slog.Error(err)) @@ -547,9 +551,9 @@ func (r *Runner) heartbeatRoutine(ctx context.Context) { ticker.Reset(r.updateInterval) continue } - r.logger.Info(ctx, "attempting graceful cancelation") + r.logger.Info(ctx, "attempting graceful cancellation") r.Cancel() - // Hard-cancel the job after a minute of pending cancelation. + // Mark the job as failed after a minute of pending cancellation. timer := time.NewTimer(r.forceCancelInterval) select { case <-timer.C: diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index e1a1f859578c9..49bdcdbb326ff 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -327,6 +327,7 @@ export interface DeploymentValues { readonly redirect_to_access_url?: boolean readonly http_address?: string readonly autobuild_poll_interval?: number + readonly job_hang_detector_interval?: number readonly derp?: DERP readonly prometheus?: PrometheusConfig readonly pprof?: PprofConfig