From 1e532458488af85d639a8fc06783266ef55e965d Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 8 Sep 2023 14:53:05 +0400 Subject: [PATCH 1/2] chore: add Acquirer to provisionerdserver pkg Signed-off-by: Spike Curtis --- coderd/provisionerdserver/acquirer.go | 481 ++++++++++++++++ coderd/provisionerdserver/acquirer_test.go | 519 ++++++++++++++++++ .../provisionerdserver/provisionerdserver.go | 24 + enterprise/provisionerd/remoteprovisioners.go | 2 +- 4 files changed, 1025 insertions(+), 1 deletion(-) create mode 100644 coderd/provisionerdserver/acquirer.go create mode 100644 coderd/provisionerdserver/acquirer_test.go diff --git a/coderd/provisionerdserver/acquirer.go b/coderd/provisionerdserver/acquirer.go new file mode 100644 index 0000000000000..1ee723fb8d4ce --- /dev/null +++ b/coderd/provisionerdserver/acquirer.go @@ -0,0 +1,481 @@ +package provisionerdserver + +import ( + "context" + "database/sql" + "encoding/json" + "strings" + "sync" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/google/uuid" + "golang.org/x/exp/slices" + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/database/pubsub" +) + +const ( + EventJobPosted = "provisioner_job_posted" + dbMaxBackoff = 10 * time.Second + backupPollDuration = 30 * time.Second +) + +// Acquirer is shared among multiple routines that need to call +// database.Store.AcquireProvisionerJob. The callers that acquire jobs are called "acquirees". The +// goal is to minimize polling the database (i.e. lower our average query rate) and simplify the +// acquiree's logic by handling retrying the database if a job is not available at the time of the +// call. +// +// When multiple acquirees share a set of provisioner types and tags, we define them as part of the +// same "domain". Only one acquiree from each domain may query the database at a time. If the +// database returns no jobs for that acquiree, the entire domain waits until the Acquirer is +// notified over the pubsub of a new job acceptable to the domain. +// +// As a backup to pubsub notifications, each domain is allowed to query periodically once every 30s. +// This ensures jobs are not stuck permanently if the service that created them fails to publish +// (e.g. a crash). +type Acquirer struct { + ctx context.Context + logger slog.Logger + store AcquirerStore + ps pubsub.Pubsub + + mu sync.Mutex + q map[dKey]domain + + // testing only + backupPollDuration time.Duration +} + +type AcquirerOption func(*Acquirer) + +func TestingBackupPollDuration(dur time.Duration) AcquirerOption { + return func(a *Acquirer) { + a.backupPollDuration = dur + } +} + +// AcquirerStore is the subset of database.Store that the Acquirer needs +type AcquirerStore interface { + AcquireProvisionerJob(context.Context, database.AcquireProvisionerJobParams) (database.ProvisionerJob, error) +} + +func NewAcquirer(ctx context.Context, logger slog.Logger, store AcquirerStore, ps pubsub.Pubsub, + opts ...AcquirerOption, +) *Acquirer { + a := &Acquirer{ + ctx: ctx, + logger: logger, + store: store, + ps: ps, + q: make(map[dKey]domain), + backupPollDuration: backupPollDuration, + } + for _, opt := range opts { + opt(a) + } + a.subscribe() + return a +} + +// AcquireJob acquires a job with one of the given provisioner types and compatible +// tags from the database. 
The call blocks until a job is acquired, the context is +// done, or the database returns an error _other_ than that no jobs are available. +// If no jobs are available, this method handles retrying as appropriate. +func (a *Acquirer) AcquireJob( + ctx context.Context, worker uuid.UUID, pt []database.ProvisionerType, tags Tags, +) ( + retJob database.ProvisionerJob, retErr error, +) { + logger := a.logger.With( + slog.F("worker_id", worker), + slog.F("provisioner_types", pt), + slog.F("tags", tags)) + logger.Debug(ctx, "acquiring job") + dk := domainKey(pt, tags) + dbTags, err := tags.ToJSON() + if err != nil { + return database.ProvisionerJob{}, err + } + // buffer of 1 so that cancel doesn't deadlock while writing to the channel + clearance := make(chan struct{}, 1) + //nolint:gocritic // Provisionerd has specific authz rules. + principal := dbauthz.AsProvisionerd(ctx) + for { + a.want(pt, tags, clearance) + select { + case <-ctx.Done(): + err := ctx.Err() + logger.Debug(ctx, "acquiring job canceled", slog.Error(err)) + internalError := a.cancel(dk, clearance) + if internalError != nil { + // internalError takes precedence + return database.ProvisionerJob{}, internalError + } + return database.ProvisionerJob{}, err + case <-clearance: + logger.Debug(ctx, "got clearance to call database") + job, err := a.store.AcquireProvisionerJob(principal, database.AcquireProvisionerJobParams{ + StartedAt: sql.NullTime{ + Time: dbtime.Now(), + Valid: true, + }, + WorkerID: uuid.NullUUID{ + UUID: worker, + Valid: true, + }, + Types: pt, + Tags: dbTags, + }) + if xerrors.Is(err, sql.ErrNoRows) { + logger.Debug(ctx, "no job available") + continue + } + // we are not going to retry, so signal we are done + internalError := a.done(dk, clearance) + if internalError != nil { + // internal error takes precedence + return database.ProvisionerJob{}, internalError + } + if err != nil { + logger.Warn(ctx, "error attempting to acquire job", slog.Error(err)) + return database.ProvisionerJob{}, xerrors.Errorf("failed to acquire job: %w", err) + } + logger.Debug(ctx, "successfully acquired job") + return job, nil + } + } +} + +// want signals that an acquiree wants clearance to query for a job with the given dKey. +func (a *Acquirer) want(pt []database.ProvisionerType, tags Tags, clearance chan<- struct{}) { + dk := domainKey(pt, tags) + a.mu.Lock() + defer a.mu.Unlock() + cleared := false + d, ok := a.q[dk] + if !ok { + ctx, cancel := context.WithCancel(a.ctx) + d = domain{ + ctx: ctx, + cancel: cancel, + a: a, + key: dk, + pt: pt, + tags: tags, + acquirees: make(map[chan<- struct{}]*acquiree), + } + a.q[dk] = d + go d.poll(a.backupPollDuration) + // this is a new request for this dKey, so is cleared. + cleared = true + } + w, ok := d.acquirees[clearance] + if !ok { + w = &acquiree{clearance: clearance} + d.acquirees[clearance] = w + } + // pending means that we got a job posting for this dKey while we were + // querying, so we should clear this acquiree to retry another time. + if w.pending { + cleared = true + w.pending = false + } + w.inProgress = cleared + if cleared { + // this won't block because clearance is buffered. + clearance <- struct{}{} + } +} + +// cancel signals that an acquiree no longer wants clearance to query. Any error returned is a serious internal error +// indicating that integrity of the internal state is corrupted by a code bug. 
+func (a *Acquirer) cancel(dk dKey, clearance chan<- struct{}) error { + a.mu.Lock() + defer a.mu.Unlock() + d, ok := a.q[dk] + if !ok { + // this is a code error, as something removed the dKey early, or cancel + // was called twice. + err := xerrors.New("canceled non-existent job acquisition") + a.logger.Critical(a.ctx, "internal error", slog.Error(err)) + return err + } + w, ok := d.acquirees[clearance] + if !ok { + // this is a code error, as something removed the dKey early, or cancel + // was called twice. + err := xerrors.New("canceled non-existent job acquisition") + a.logger.Critical(a.ctx, "internal error", slog.Error(err)) + return err + } + delete(d.acquirees, clearance) + if w.inProgress && len(d.acquirees) > 0 { + // this one canceled before querying, so give another acquiree a chance + // instead + for _, other := range d.acquirees { + if other.inProgress { + err := xerrors.New("more than one acquiree in progress for same key") + a.logger.Critical(a.ctx, "internal error", slog.Error(err)) + return err + } + other.inProgress = true + other.clearance <- struct{}{} + break // just one + } + } + if len(d.acquirees) == 0 { + d.cancel() + delete(a.q, dk) + } + return nil +} + +// done signals that the acquiree has completed acquiring a job (usually successfully, but we also get this call if +// there is a database error other than ErrNoRows). Any error returned is a serious internal error indicating that +// integrity of the internal state is corrupted by a code bug. +func (a *Acquirer) done(dk dKey, clearance chan struct{}) error { + a.mu.Lock() + defer a.mu.Unlock() + d, ok := a.q[dk] + if !ok { + // this is a code error, as something removed the dKey early, or done + // was called twice. + err := xerrors.New("done with non-existent job acquisition") + a.logger.Critical(a.ctx, "internal error", slog.Error(err)) + return err + } + w, ok := d.acquirees[clearance] + if !ok { + // this is a code error, as something removed the dKey early, or done + // was called twice. + err := xerrors.New("canceled non-existent job acquisition") + a.logger.Critical(a.ctx, "internal error", slog.Error(err)) + return err + } + if !w.inProgress { + err := xerrors.New("done acquiree was not in progress") + a.logger.Critical(a.ctx, "internal error", slog.Error(err)) + return err + } + delete(d.acquirees, clearance) + if len(d.acquirees) == 0 { + d.cancel() + delete(a.q, dk) + return nil + } + // in the mainline, this means that the acquiree successfully got a job. 
+ // if any others are waiting, clear one of them to try to get a job next so + // that we process the jobs until there are no more acquirees or the database + // is empty of jobs meeting our criteria + for _, other := range d.acquirees { + if other.inProgress { + err := xerrors.New("more than one acquiree in progress for same key") + a.logger.Critical(a.ctx, "internal error", slog.Error(err)) + return err + } + other.inProgress = true + other.clearance <- struct{}{} + break // just one + } + return nil +} + +func (a *Acquirer) subscribe() { + subscribed := make(chan struct{}) + go func() { + defer close(subscribed) + eb := backoff.NewExponentialBackOff() + eb.MaxElapsedTime = 0 // retry indefinitely + eb.MaxInterval = dbMaxBackoff + bkoff := backoff.WithContext(eb, a.ctx) + var cancel context.CancelFunc + err := backoff.Retry(func() error { + cancelFn, err := a.ps.SubscribeWithErr(EventJobPosted, a.jobPosted) + if err != nil { + a.logger.Warn(a.ctx, "failed to subscribe to job postings", slog.Error(err)) + return err + } + cancel = cancelFn + return nil + }, bkoff) + if err != nil { + if a.ctx.Err() == nil { + a.logger.Error(a.ctx, "code bug: retry failed before context canceled", slog.Error(err)) + } + return + } + defer cancel() + bkoff.Reset() + a.logger.Debug(a.ctx, "subscribed to job postings") + + // unblock the outer function from returning + subscribed <- struct{}{} + + // hold subscriptions open until context is canceled + <-a.ctx.Done() + }() + <-subscribed +} + +func (a *Acquirer) jobPosted(ctx context.Context, message []byte, err error) { + if xerrors.Is(err, pubsub.ErrDroppedMessages) { + a.logger.Warn(a.ctx, "pubsub may have dropped job postings") + a.clearOrPendAll() + return + } + if err != nil { + a.logger.Warn(a.ctx, "unhandled pubsub error", slog.Error(err)) + return + } + posting := JobPosting{} + err = json.Unmarshal(message, &posting) + if err != nil { + a.logger.Error(a.ctx, "unable to parse job posting", + slog.F("message", string(message)), + slog.Error(err), + ) + return + } + a.logger.Debug(ctx, "got job posting", slog.F("posting", posting)) + + a.mu.Lock() + defer a.mu.Unlock() + for _, d := range a.q { + if d.contains(posting) { + a.clearOrPendLocked(d) + // we only need to wake up a single domain since there is only one + // new job available + return + } + } +} + +func (a *Acquirer) clearOrPendAll() { + a.mu.Lock() + defer a.mu.Unlock() + for _, d := range a.q { + a.clearOrPendLocked(d) + } +} + +func (a *Acquirer) clearOrPend(d domain) { + a.mu.Lock() + defer a.mu.Unlock() + if len(d.acquirees) == 0 { + // this can happen if the domain is removed right around the time the + // backup poll (which calls this function) triggers. Nothing to do + // since there are no acquirees. + return + } + a.clearOrPendLocked(d) +} + +func (*Acquirer) clearOrPendLocked(d domain) { + // MUST BE CALLED HOLDING THE a.mu LOCK + var nominee *acquiree + for _, w := range d.acquirees { + if nominee == nil { + nominee = w + } + // acquiree in progress always takes precedence, since we don't want to + // wake up more than one acquiree per dKey at a time. 
+ if w.inProgress { + nominee = w + break + } + } + if nominee.inProgress { + nominee.pending = true + return + } + nominee.inProgress = true + nominee.clearance <- struct{}{} +} + +type dKey string + +func domainKey(pt []database.ProvisionerType, tags Tags) dKey { + sb := strings.Builder{} + for _, t := range pt { + _, _ = sb.WriteString(string(t)) + _ = sb.WriteByte(0x00) + } + _ = sb.WriteByte(0x00) + var keys []string + for k := range tags { + keys = append(keys, k) + } + slices.Sort(keys) + for _, k := range keys { + _, _ = sb.WriteString(k) + _ = sb.WriteByte(0x00) + _, _ = sb.WriteString(tags[k]) + _ = sb.WriteByte(0x00) + } + return dKey(sb.String()) +} + +// acquiree represents a specific client of Acquirer that wants to acquire a job +type acquiree struct { + clearance chan<- struct{} + // inProgress is true when the acquiree was granted clearance and a query + // is possibly in progress. + inProgress bool + // pending is true if we get a job posting while a query is in progress, so + // that we know to try again, even if we didn't get a job on the query. + pending bool +} + +// domain represents a set of acquirees with the same provisioner types and +// tags. Acquirees in the same domain are restricted such that only one queries +// the database at a time. +type domain struct { + ctx context.Context + cancel context.CancelFunc + a *Acquirer + key dKey + pt []database.ProvisionerType + tags Tags + acquirees map[chan<- struct{}]*acquiree +} + +func (d domain) contains(p JobPosting) bool { + if !slices.Contains(d.pt, p.ProvisionerType) { + return false + } + for k, v := range p.Tags { + dv, ok := d.tags[k] + if !ok { + return false + } + if v != dv { + return false + } + } + return true +} + +func (d domain) poll(dur time.Duration) { + tkr := time.NewTicker(dur) + defer tkr.Stop() + for { + select { + case <-d.ctx.Done(): + return + case <-tkr.C: + d.a.clearOrPend(d) + } + } +} + +type JobPosting struct { + ProvisionerType database.ProvisionerType `json:"type"` + Tags map[string]string `json:"tags"` +} diff --git a/coderd/provisionerdserver/acquirer_test.go b/coderd/provisionerdserver/acquirer_test.go new file mode 100644 index 0000000000000..00ff82dfd3b96 --- /dev/null +++ b/coderd/provisionerdserver/acquirer_test.go @@ -0,0 +1,519 @@ +package provisionerdserver_test + +import ( + "context" + "database/sql" + "encoding/json" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "golang.org/x/exp/slices" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbfake" + "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/coder/v2/coderd/provisionerdserver" + "github.com/coder/coder/v2/testutil" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +// TestAcquirer_Store tests that a database.Store is accepted as a provisionerdserver.AcquirerStore +func TestAcquirer_Store(t *testing.T) { + t.Parallel() + db := dbfake.New() + ps := pubsub.NewInMemory() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + _ = provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), db, ps) +} + +func TestAcquirer_Single(t *testing.T) { + t.Parallel() + fs := newFakeOrderedStore() + ps := pubsub.NewInMemory() + ctx, cancel := context.WithTimeout(context.Background(), 
testutil.WaitShort) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps) + + workerID := uuid.New() + pt := []database.ProvisionerType{database.ProvisionerTypeEcho} + tags := provisionerdserver.Tags{ + "foo": "bar", + } + acquiree := newTestAcquiree(t, workerID, pt, tags) + jobID := uuid.New() + go func() { + err := fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil) + assert.NoError(t, err) + }() + acquiree.startAcquire(ctx, uut) + job := acquiree.success(ctx) + require.Equal(t, jobID, job.ID) + require.Len(t, fs.params, 1) + require.Equal(t, workerID, fs.params[0].WorkerID.UUID) +} + +// TestAcquirer_MultipleSameDomain tests multiple acquirees with the same provisioners and tags +func TestAcquirer_MultipleSameDomain(t *testing.T) { + t.Parallel() + fs := newFakeOrderedStore() + ps := pubsub.NewInMemory() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps) + + acquirees := make([]*testAcquiree, 0, 10) + jobIDs := make(map[uuid.UUID]bool) + workerIDs := make(map[uuid.UUID]bool) + pt := []database.ProvisionerType{database.ProvisionerTypeEcho} + tags := provisionerdserver.Tags{ + "foo": "bar", + } + for i := 0; i < 10; i++ { + wID := uuid.New() + workerIDs[wID] = true + a := newTestAcquiree(t, wID, pt, tags) + acquirees = append(acquirees, a) + a.startAcquire(ctx, uut) + } + go func() { + for i := 0; i < 10; i++ { + jobID := uuid.New() + jobIDs[jobID] = true + err := fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil) + assert.NoError(t, err) + } + }() + gotJobIDs := make(map[uuid.UUID]bool) + for i := 0; i < 10; i++ { + j := acquirees[i].success(ctx) + gotJobIDs[j.ID] = true + } + require.Equal(t, jobIDs, gotJobIDs) + require.Len(t, fs.overlaps, 0) + gotWorkerCalls := make(map[uuid.UUID]bool) + for _, params := range fs.params { + gotWorkerCalls[params.WorkerID.UUID] = true + } + require.Equal(t, workerIDs, gotWorkerCalls) +} + +// TestAcquirer_WaitsOnNoJobs tests that after a call that returns no jobs, Acquirer waits for a new +// job posting before retrying +func TestAcquirer_WaitsOnNoJobs(t *testing.T) { + t.Parallel() + fs := newFakeOrderedStore() + ps := pubsub.NewInMemory() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps) + + workerID := uuid.New() + pt := []database.ProvisionerType{database.ProvisionerTypeEcho} + tags := provisionerdserver.Tags{ + "foo": "bar", + } + acquiree := newTestAcquiree(t, workerID, pt, tags) + jobID := uuid.New() + go func() { + err := fs.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows) + assert.NoError(t, err) + err = fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil) + assert.NoError(t, err) + }() + acquiree.startAcquire(ctx, uut) + require.Eventually(t, func() bool { + fs.mu.Lock() + defer fs.mu.Unlock() + return len(fs.params) == 1 + }, testutil.WaitShort, testutil.IntervalFast) + acquiree.requireBlocked() + + // First send in some with incompatible tags & types + postJob(t, ps, database.ProvisionerTypeEcho, provisionerdserver.Tags{ + "cool": "tapes", + "strong": "bad", + }) + postJob(t, ps, database.ProvisionerTypeEcho, provisionerdserver.Tags{ + "foo": "fighters", + }) + 
postJob(t, ps, database.ProvisionerTypeTerraform, provisionerdserver.Tags{ + "foo": "bar", + }) + acquiree.requireBlocked() + + // compatible tags + postJob(t, ps, database.ProvisionerTypeEcho, provisionerdserver.Tags{}) + job := acquiree.success(ctx) + require.Equal(t, jobID, job.ID) +} + +// TestAcquirer_RetriesPending tests that if we get a job posting while a db call is in progress +// we retry to acquire a job immediately, even if the first call returned no jobs. We want this +// behavior since the query that found no jobs could have resolved before the job was posted, but +// the query result could reach us later than the posting over the pubsub. +func TestAcquirer_RetriesPending(t *testing.T) { + t.Parallel() + fs := newFakeOrderedStore() + ps := pubsub.NewInMemory() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps) + + workerID := uuid.New() + pt := []database.ProvisionerType{database.ProvisionerTypeEcho} + tags := provisionerdserver.Tags{ + "foo": "bar", + } + acquiree := newTestAcquiree(t, workerID, pt, tags) + jobID := uuid.New() + + acquiree.startAcquire(ctx, uut) + require.Eventually(t, func() bool { + fs.mu.Lock() + defer fs.mu.Unlock() + return len(fs.params) == 1 + }, testutil.WaitShort, testutil.IntervalFast) + + // First call to DB is in progress. Send in posting + postJob(t, ps, database.ProvisionerTypeEcho, provisionerdserver.Tags{}) + // there is a race between the posting being processed and the DB call + // returning. In either case we should retry, but we're trying to hit the + // case where the posting is processed first, so sleep a little bit to give + // it a chance. + time.Sleep(testutil.IntervalMedium) + + // Now, when first DB call returns ErrNoRows we retry. 
+ err := fs.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows) + require.NoError(t, err) + err = fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil) + require.NoError(t, err) + + job := acquiree.success(ctx) + require.Equal(t, jobID, job.ID) +} + +// TestAcquirer_DifferentDomains tests that acquirees with different tags don't block each other +func TestAcquirer_DifferentDomains(t *testing.T) { + t.Parallel() + fs := newFakeTaggedStore(t) + ps := pubsub.NewInMemory() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + pt := []database.ProvisionerType{database.ProvisionerTypeEcho} + worker0 := uuid.New() + tags0 := provisionerdserver.Tags{ + "worker": "0", + } + acquiree0 := newTestAcquiree(t, worker0, pt, tags0) + worker1 := uuid.New() + tags1 := provisionerdserver.Tags{ + "worker": "1", + } + acquiree1 := newTestAcquiree(t, worker1, pt, tags1) + jobID := uuid.New() + fs.jobs = []database.ProvisionerJob{ + {ID: jobID, Provisioner: database.ProvisionerTypeEcho, Tags: database.StringMap{"worker": "1"}}, + } + + uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps) + + ctx0, cancel0 := context.WithCancel(ctx) + defer cancel0() + acquiree0.startAcquire(ctx0, uut) + select { + case params := <-fs.params: + require.Equal(t, worker0, params.WorkerID.UUID) + case <-ctx.Done(): + t.Fatal("timed out waiting for call to database from worker0") + } + acquiree0.requireBlocked() + + // worker1 should not be blocked by worker0, as they are different tags + acquiree1.startAcquire(ctx, uut) + job := acquiree1.success(ctx) + require.Equal(t, jobID, job.ID) + + cancel0() + acquiree0.requireCanceled(ctx) +} + +func TestAcquirer_BackupPoll(t *testing.T) { + t.Parallel() + fs := newFakeOrderedStore() + ps := pubsub.NewInMemory() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + uut := provisionerdserver.NewAcquirer( + ctx, logger.Named("acquirer"), fs, ps, + provisionerdserver.TestingBackupPollDuration(testutil.IntervalMedium), + ) + + workerID := uuid.New() + pt := []database.ProvisionerType{database.ProvisionerTypeEcho} + tags := provisionerdserver.Tags{ + "foo": "bar", + } + acquiree := newTestAcquiree(t, workerID, pt, tags) + jobID := uuid.New() + go func() { + err := fs.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows) + assert.NoError(t, err) + err = fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil) + assert.NoError(t, err) + }() + acquiree.startAcquire(ctx, uut) + job := acquiree.success(ctx) + require.Equal(t, jobID, job.ID) +} + +// TestAcquirer_UnblockOnCancel tests that a canceled call doesn't block a call +// from the same domain. 
+func TestAcquirer_UnblockOnCancel(t *testing.T) { + t.Parallel() + fs := newFakeOrderedStore() + ps := pubsub.NewInMemory() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + pt := []database.ProvisionerType{database.ProvisionerTypeEcho} + worker0 := uuid.New() + tags := provisionerdserver.Tags{ + "foo": "bar", + } + acquiree0 := newTestAcquiree(t, worker0, pt, tags) + worker1 := uuid.New() + acquiree1 := newTestAcquiree(t, worker1, pt, tags) + jobID := uuid.New() + + uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps) + + // queue up 2 responses --- we may not need both, since acquiree0 will + // usually cancel before calling, but cancel is async, so it might call. + for i := 0; i < 2; i++ { + err := fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil) + require.NoError(t, err) + } + + ctx0, cancel0 := context.WithCancel(ctx) + cancel0() + acquiree0.startAcquire(ctx0, uut) + acquiree1.startAcquire(ctx, uut) + job := acquiree1.success(ctx) + require.Equal(t, jobID, job.ID) +} + +func postJob(t *testing.T, ps pubsub.Pubsub, pt database.ProvisionerType, tags provisionerdserver.Tags) { + msg, err := json.Marshal(provisionerdserver.JobPosting{ + ProvisionerType: pt, + Tags: tags, + }) + require.NoError(t, err) + err = ps.Publish(provisionerdserver.EventJobPosted, msg) + require.NoError(t, err) +} + +// fakeOrderedStore is a fake store that lets tests send AcquireProvisionerJob +// results in order over a channel, and tests for overlapped calls. +type fakeOrderedStore struct { + jobs chan database.ProvisionerJob + errors chan error + + mu sync.Mutex + params []database.AcquireProvisionerJobParams + + // inflight and overlaps track whether any calls from workers overlap with + // one another + inflight map[uuid.UUID]bool + overlaps [][]uuid.UUID +} + +func newFakeOrderedStore() *fakeOrderedStore { + return &fakeOrderedStore{ + // buffer the channels so that we can queue up lots of responses to + // occur nearly simultaneously + jobs: make(chan database.ProvisionerJob, 100), + errors: make(chan error, 100), + inflight: make(map[uuid.UUID]bool), + } +} + +func (s *fakeOrderedStore) AcquireProvisionerJob( + _ context.Context, params database.AcquireProvisionerJobParams, +) ( + database.ProvisionerJob, error, +) { + s.mu.Lock() + s.params = append(s.params, params) + for workerID := range s.inflight { + s.overlaps = append(s.overlaps, []uuid.UUID{workerID, params.WorkerID.UUID}) + } + s.inflight[params.WorkerID.UUID] = true + s.mu.Unlock() + + job := <-s.jobs + err := <-s.errors + + s.mu.Lock() + delete(s.inflight, params.WorkerID.UUID) + s.mu.Unlock() + + return job, err +} + +func (s *fakeOrderedStore) sendCtx(ctx context.Context, job database.ProvisionerJob, err error) error { + select { + case <-ctx.Done(): + return ctx.Err() + case s.jobs <- job: + // OK + } + select { + case <-ctx.Done(): + return ctx.Err() + case s.errors <- err: + // OK + } + return nil +} + +// fakeTaggedStore is a test store that allows tests to specify which jobs are +// available, and returns them to callers with the appropriate provisioner type +// and tags. It doesn't care about the order. 
+type fakeTaggedStore struct { + t *testing.T + mu sync.Mutex + jobs []database.ProvisionerJob + params chan database.AcquireProvisionerJobParams +} + +func newFakeTaggedStore(t *testing.T) *fakeTaggedStore { + return &fakeTaggedStore{ + t: t, + params: make(chan database.AcquireProvisionerJobParams, 100), + } +} + +func (s *fakeTaggedStore) AcquireProvisionerJob( + _ context.Context, params database.AcquireProvisionerJobParams, +) ( + database.ProvisionerJob, error, +) { + defer func() { s.params <- params }() + var tags provisionerdserver.Tags + err := json.Unmarshal(params.Tags, &tags) + if !assert.NoError(s.t, err) { + return database.ProvisionerJob{}, err + } + s.mu.Lock() + defer s.mu.Unlock() +jobLoop: + for i, job := range s.jobs { + if !slices.Contains(params.Types, job.Provisioner) { + continue + } + for k, v := range job.Tags { + pv, ok := tags[k] + if !ok { + continue jobLoop + } + if v != pv { + continue jobLoop + } + } + // found a job! + s.jobs = append(s.jobs[:i], s.jobs[i+1:]...) + return job, nil + } + return database.ProvisionerJob{}, sql.ErrNoRows +} + +// testAcquiree is a helper type that handles asynchronously calling AcquireJob +// and asserting whether or not it returns, blocks, or is canceled. +type testAcquiree struct { + t *testing.T + workerID uuid.UUID + pt []database.ProvisionerType + tags provisionerdserver.Tags + ec chan error + jc chan database.ProvisionerJob +} + +func newTestAcquiree(t *testing.T, workerID uuid.UUID, pt []database.ProvisionerType, tags provisionerdserver.Tags) *testAcquiree { + return &testAcquiree{ + t: t, + workerID: workerID, + pt: pt, + tags: tags, + ec: make(chan error, 1), + jc: make(chan database.ProvisionerJob, 1), + } +} + +func (a *testAcquiree) startAcquire(ctx context.Context, uut *provisionerdserver.Acquirer) { + go func() { + j, e := uut.AcquireJob(ctx, a.workerID, a.pt, a.tags) + a.ec <- e + a.jc <- j + }() +} + +func (a *testAcquiree) success(ctx context.Context) database.ProvisionerJob { + select { + case <-ctx.Done(): + a.t.Fatal("timeout waiting for AcquireJob error") + case err := <-a.ec: + require.NoError(a.t, err) + } + select { + case <-ctx.Done(): + a.t.Fatal("timeout waiting for AcquireJob job") + case job := <-a.jc: + return job + } + // unhittable + return database.ProvisionerJob{} +} + +func (a *testAcquiree) requireBlocked() { + select { + case <-a.ec: + a.t.Fatal("AcquireJob should block") + default: + // OK + } +} + +func (a *testAcquiree) requireCanceled(ctx context.Context) { + select { + case err := <-a.ec: + require.ErrorIs(a.t, err, context.Canceled) + case <-ctx.Done(): + a.t.Fatal("timed out waiting for AcquireJob") + } + select { + case job := <-a.jc: + require.Equal(a.t, uuid.Nil, job.ID) + case <-ctx.Done(): + a.t.Fatal("timed out waiting for AcquireJob") + } +} diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index 05249d65986a4..38556ba810098 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -79,6 +79,30 @@ type server struct { TimeNowFn func() time.Time } +// We use the null byte (0x00) in generating a canonical map key for tags, so +// it cannot be used in the tag keys or values. 
+ +var ErrorTagsContainNullByte = xerrors.New("tags cannot contain the null byte (0x00)") + +type Tags map[string]string + +func (t Tags) ToJSON() (json.RawMessage, error) { + r, err := json.Marshal(t) + if err != nil { + return nil, err + } + return r, err +} + +func (t Tags) Valid() error { + for k, v := range t { + if slices.Contains([]byte(k), 0x00) || slices.Contains([]byte(v), 0x00) { + return ErrorTagsContainNullByte + } + } + return nil +} + func NewServer( accessURL *url.URL, id uuid.UUID, diff --git a/enterprise/provisionerd/remoteprovisioners.go b/enterprise/provisionerd/remoteprovisioners.go index c56459ef3109d..26c93322e662a 100644 --- a/enterprise/provisionerd/remoteprovisioners.go +++ b/enterprise/provisionerd/remoteprovisioners.go @@ -40,7 +40,7 @@ import ( // version; right now, only the unit tests implement this interface. type Executor interface { // Execute a provisioner that connects back to the remoteConnector. errCh - // allows signalling of errors asynchronously and is closed on completion + // allows signaling of errors asynchronously and is closed on completion // with no error. Execute( ctx context.Context, From 1ceac512331283ed449e3348f76dc6f12b2a017b Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 13 Sep 2023 15:23:14 +0400 Subject: [PATCH 2/2] code review improvements & fixes Signed-off-by: Spike Curtis --- coderd/provisionerdserver/acquirer.go | 31 ++++++++++------ coderd/provisionerdserver/acquirer_test.go | 41 +++++++++------------- 2 files changed, 38 insertions(+), 34 deletions(-) diff --git a/coderd/provisionerdserver/acquirer.go b/coderd/provisionerdserver/acquirer.go index 1ee723fb8d4ce..7fb759faa2612 100644 --- a/coderd/provisionerdserver/acquirer.go +++ b/coderd/provisionerdserver/acquirer.go @@ -21,8 +21,9 @@ import ( ) const ( - EventJobPosted = "provisioner_job_posted" - dbMaxBackoff = 10 * time.Second + EventJobPosted = "provisioner_job_posted" + dbMaxBackoff = 10 * time.Second + // backPollDuration is the period for the backup polling described in Acquirer comment backupPollDuration = 30 * time.Second ) @@ -201,17 +202,17 @@ func (a *Acquirer) cancel(dk dKey, clearance chan<- struct{}) error { defer a.mu.Unlock() d, ok := a.q[dk] if !ok { - // this is a code error, as something removed the dKey early, or cancel + // this is a code error, as something removed the domain early, or cancel // was called twice. - err := xerrors.New("canceled non-existent job acquisition") + err := xerrors.New("cancel for domain that doesn't exist") a.logger.Critical(a.ctx, "internal error", slog.Error(err)) return err } w, ok := d.acquirees[clearance] if !ok { - // this is a code error, as something removed the dKey early, or cancel + // this is a code error, as something removed the acquiree early, or cancel // was called twice. - err := xerrors.New("canceled non-existent job acquisition") + err := xerrors.New("cancel for an acquiree that doesn't exist") a.logger.Critical(a.ctx, "internal error", slog.Error(err)) return err } @@ -245,9 +246,9 @@ func (a *Acquirer) done(dk dKey, clearance chan struct{}) error { defer a.mu.Unlock() d, ok := a.q[dk] if !ok { - // this is a code error, as something removed the dKey early, or done + // this is a code error, as something removed the domain early, or done // was called twice. 
- err := xerrors.New("done with non-existent job acquisition") + err := xerrors.New("done for a domain that doesn't exist") a.logger.Critical(a.ctx, "internal error", slog.Error(err)) return err } @@ -255,7 +256,7 @@ func (a *Acquirer) done(dk dKey, clearance chan struct{}) error { if !ok { // this is a code error, as something removed the dKey early, or done // was called twice. - err := xerrors.New("canceled non-existent job acquisition") + err := xerrors.New("done for an acquiree that doesn't exist") a.logger.Critical(a.ctx, "internal error", slog.Error(err)) return err } @@ -401,9 +402,19 @@ func (*Acquirer) clearOrPendLocked(d domain) { type dKey string +// domainKey generates a canonical map key for the given provisioner types and +// tags. It uses the null byte (0x00) as a delimiter because it is an +// unprintable control character and won't show up in any "reasonable" set of +// string tags, even in non-Latin scripts. It is important that Tags are +// validated not to contain this control character prior to use. func domainKey(pt []database.ProvisionerType, tags Tags) dKey { + // make a copy of pt before sorting, so that we don't mutate the original + // slice or underlying array. + pts := make([]database.ProvisionerType, len(pt)) + copy(pts, pt) + slices.Sort(pts) sb := strings.Builder{} - for _, t := range pt { + for _, t := range pts { _, _ = sb.WriteString(string(t)) _ = sb.WriteByte(0x00) } diff --git a/coderd/provisionerdserver/acquirer_test.go b/coderd/provisionerdserver/acquirer_test.go index 00ff82dfd3b96..6d72da5f7ffe4 100644 --- a/coderd/provisionerdserver/acquirer_test.go +++ b/coderd/provisionerdserver/acquirer_test.go @@ -54,10 +54,8 @@ func TestAcquirer_Single(t *testing.T) { } acquiree := newTestAcquiree(t, workerID, pt, tags) jobID := uuid.New() - go func() { - err := fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil) - assert.NoError(t, err) - }() + err := fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil) + require.NoError(t, err) acquiree.startAcquire(ctx, uut) job := acquiree.success(ctx) require.Equal(t, jobID, job.ID) @@ -89,14 +87,12 @@ func TestAcquirer_MultipleSameDomain(t *testing.T) { acquirees = append(acquirees, a) a.startAcquire(ctx, uut) } - go func() { - for i := 0; i < 10; i++ { - jobID := uuid.New() - jobIDs[jobID] = true - err := fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil) - assert.NoError(t, err) - } - }() + for i := 0; i < 10; i++ { + jobID := uuid.New() + jobIDs[jobID] = true + err := fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil) + require.NoError(t, err) + } gotJobIDs := make(map[uuid.UUID]bool) for i := 0; i < 10; i++ { j := acquirees[i].success(ctx) @@ -129,12 +125,10 @@ func TestAcquirer_WaitsOnNoJobs(t *testing.T) { } acquiree := newTestAcquiree(t, workerID, pt, tags) jobID := uuid.New() - go func() { - err := fs.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows) - assert.NoError(t, err) - err = fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil) - assert.NoError(t, err) - }() + err := fs.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows) + require.NoError(t, err) + err = fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil) + require.NoError(t, err) acquiree.startAcquire(ctx, uut) require.Eventually(t, func() bool { fs.mu.Lock() @@ -274,12 +268,10 @@ func TestAcquirer_BackupPoll(t *testing.T) { } acquiree := newTestAcquiree(t, workerID, pt, tags) jobID := uuid.New() - go func() { - err := fs.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows) - assert.NoError(t, err) - err = 
fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil) - assert.NoError(t, err) - }() + err := fs.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows) + require.NoError(t, err) + err = fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil) + require.NoError(t, err) acquiree.startAcquire(ctx, uut) job := acquiree.success(ctx) require.Equal(t, jobID, job.ID) @@ -323,6 +315,7 @@ func TestAcquirer_UnblockOnCancel(t *testing.T) { } func postJob(t *testing.T, ps pubsub.Pubsub, pt database.ProvisionerType, tags provisionerdserver.Tags) { + t.Helper() msg, err := json.Marshal(provisionerdserver.JobPosting{ ProvisionerType: pt, Tags: tags,
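
Usage sketch (not part of the patches above). The tests already exercise both halves of the Acquirer contract; the snippet below condenses them into one place for readers skimming the diff. It is a minimal sketch, assuming an already-constructed database.Store and pubsub.Pubsub; the package name, the helpers acquireOne and notifyJobPosted, and the example tag values are illustrative only, while NewAcquirer, AcquireJob, Tags, JobPosting, and EventJobPosted are the identifiers added in PATCH 1/2.

package example

import (
	"context"
	"encoding/json"

	"github.com/google/uuid"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/coderd/database"
	"github.com/coder/coder/v2/coderd/database/pubsub"
	"github.com/coder/coder/v2/coderd/provisionerdserver"
)

// acquireOne is the acquiree side: block until a job compatible with this
// worker's provisioner types and tags is acquired, the context is done, or
// the database returns an error other than sql.ErrNoRows.
func acquireOne(
	ctx context.Context, logger slog.Logger,
	store database.Store, ps pubsub.Pubsub, workerID uuid.UUID,
) (database.ProvisionerJob, error) {
	// Example tags only; any map[string]string without the 0x00 byte is valid.
	tags := provisionerdserver.Tags{"environment": "production"}
	if err := tags.Valid(); err != nil {
		return database.ProvisionerJob{}, err
	}
	acquirer := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), store, ps)
	return acquirer.AcquireJob(ctx, workerID,
		[]database.ProvisionerType{database.ProvisionerTypeTerraform}, tags)
}

// notifyJobPosted is the producer side: after inserting a provisioner job,
// publish a JobPosting so any waiting domain wakes up immediately instead of
// relying on the 30-second backup poll.
func notifyJobPosted(ps pubsub.Pubsub, pt database.ProvisionerType, tags map[string]string) error {
	msg, err := json.Marshal(provisionerdserver.JobPosting{ProvisionerType: pt, Tags: tags})
	if err != nil {
		return err
	}
	return ps.Publish(provisionerdserver.EventJobPosted, msg)
}

Acquirees that pass identical provisioner types and tags share a domain, so at most one of them queries the database at a time; acquirees with different tags (as in TestAcquirer_DifferentDomains) remain independent and never block each other.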