Skip to content

refactor: claim prebuilt workspace tests #17567

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 28, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 57 additions & 206 deletions enterprise/coderd/prebuilds/claim_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package prebuilds_test
import (
"context"
"database/sql"
"errors"
"slices"
"strings"
"sync/atomic"
Expand Down Expand Up @@ -35,21 +36,25 @@ type storeSpy struct {
claims *atomic.Int32
claimParams *atomic.Pointer[database.ClaimPrebuiltWorkspaceParams]
claimedWorkspace *atomic.Pointer[database.ClaimPrebuiltWorkspaceRow]

// If claimingErr is not nil, it is returned whenever ClaimPrebuiltWorkspace is called.
claimingErr error
}

func newStoreSpy(db database.Store) *storeSpy {
// newStoreSpy wraps db in a storeSpy that starts with fresh, zero-valued
// claim trackers. claimingErr is stored on the spy as-is; pass nil when no
// claiming failure should be injected.
func newStoreSpy(db database.Store, claimingErr error) *storeSpy {
	spy := &storeSpy{Store: db, claimingErr: claimingErr}
	spy.claims = new(atomic.Int32)
	spy.claimParams = new(atomic.Pointer[database.ClaimPrebuiltWorkspaceParams])
	spy.claimedWorkspace = new(atomic.Pointer[database.ClaimPrebuiltWorkspaceRow])
	return spy
}

func (m *storeSpy) InTx(fn func(store database.Store) error, opts *database.TxOptions) error {
// Pass spy down into transaction store.
return m.Store.InTx(func(store database.Store) error {
spy := newStoreSpy(store)
spy := newStoreSpy(store, m.claimingErr)
spy.claims = m.claims
spy.claimParams = m.claimParams
spy.claimedWorkspace = m.claimedWorkspace
Expand All @@ -59,6 +64,10 @@ func (m *storeSpy) InTx(fn func(store database.Store) error, opts *database.TxOp
}

func (m *storeSpy) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
if m.claimingErr != nil {
return database.ClaimPrebuiltWorkspaceRow{}, m.claimingErr
}

m.claims.Add(1)
m.claimParams.Store(&arg)
result, err := m.Store.ClaimPrebuiltWorkspace(ctx, arg)
Expand All @@ -68,32 +77,6 @@ func (m *storeSpy) ClaimPrebuiltWorkspace(ctx context.Context, arg database.Clai
return result, err
}

// errorStore is a database.Store wrapper used to inject failures into the
// prebuild claiming path: its ClaimPrebuiltWorkspace method always returns
// claimingErr. Methods it does not define are promoted from the embedded Store.
type errorStore struct {
	database.Store

	// claimingErr is the error handed back by ClaimPrebuiltWorkspace.
	claimingErr error
}

// newErrorStore returns an errorStore that wraps db and carries claimingErr.
func newErrorStore(db database.Store, claimingErr error) *errorStore {
	es := new(errorStore)
	es.Store = db
	es.claimingErr = claimingErr
	return es
}

// InTx runs fn inside a transaction of the wrapped store. The transactional
// store handed to fn is itself wrapped in an errorStore with the same
// claimingErr, so the injected failure also applies inside the transaction.
func (es *errorStore) InTx(fn func(store database.Store) error, opts *database.TxOptions) error {
	runWithFailingStore := func(tx database.Store) error {
		return fn(newErrorStore(tx, es.claimingErr))
	}

	return es.Store.InTx(runWithFailingStore, opts)
}

// ClaimPrebuiltWorkspace never touches the wrapped store: it returns a
// zero-valued row together with the configured claimingErr.
func (es *errorStore) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
	var emptyRow database.ClaimPrebuiltWorkspaceRow
	return emptyRow, es.claimingErr
}

func TestClaimPrebuild(t *testing.T) {
t.Parallel()

Expand All @@ -106,9 +89,13 @@ func TestClaimPrebuild(t *testing.T) {
presetCount = 2
)

unexpectedClaimingError := xerrors.New("unexpected claiming error")

cases := map[string]struct {
expectPrebuildClaimed bool
markPrebuildsClaimable bool
// If claimingErr is not nil, it is returned whenever ClaimPrebuiltWorkspace is called.
claimingErr error
}{
"no eligible prebuilds to claim": {
expectPrebuildClaimed: false,
Expand All @@ -118,6 +105,17 @@ func TestClaimPrebuild(t *testing.T) {
expectPrebuildClaimed: true,
markPrebuildsClaimable: true,
},

"no claimable prebuilt workspaces error is returned": {
expectPrebuildClaimed: false,
markPrebuildsClaimable: true,
claimingErr: agplprebuilds.ErrNoClaimablePrebuiltWorkspaces,
},
"unexpected claiming error is returned": {
expectPrebuildClaimed: false,
markPrebuildsClaimable: true,
claimingErr: unexpectedClaimingError,
},
}

for name, tc := range cases {
Expand All @@ -129,7 +127,8 @@ func TestClaimPrebuild(t *testing.T) {
// Setup.
ctx := testutil.Context(t, testutil.WaitSuperLong)
db, pubsub := dbtestutil.NewDB(t)
spy := newStoreSpy(db)

spy := newStoreSpy(db, tc.claimingErr)
expectedPrebuildsCount := desiredInstances * presetCount

logger := testutil.Logger(t)
Expand Down Expand Up @@ -225,8 +224,35 @@ func TestClaimPrebuild(t *testing.T) {
TemplateVersionPresetID: presets[0].ID,
})

require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, userWorkspace.LatestBuild.ID)
switch {
case tc.claimingErr != nil && errors.Is(tc.claimingErr, agplprebuilds.ErrNoClaimablePrebuiltWorkspaces):
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, userWorkspace.LatestBuild.ID)

// Then: the number of running prebuilds hasn't changed, because claiming the prebuild failed and we fell back to creating a new workspace.
currentPrebuilds, err := spy.GetRunningPrebuiltWorkspaces(ctx)
require.NoError(t, err)
require.Equal(t, expectedPrebuildsCount, len(currentPrebuilds))
return

case tc.claimingErr != nil && errors.Is(tc.claimingErr, unexpectedClaimingError):
// Then: an unexpected error happened and was propagated all the way to the caller.
require.Error(t, err)
require.ErrorContains(t, err, unexpectedClaimingError.Error())

// Then: the number of running prebuilds hasn't changed, because claiming the prebuild failed.
currentPrebuilds, err := spy.GetRunningPrebuiltWorkspaces(ctx)
require.NoError(t, err)
require.Equal(t, expectedPrebuildsCount, len(currentPrebuilds))
return

default:
// tc.claimingErr is nil scenario
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, userWorkspace.LatestBuild.ID)
}

// at this point we know that tc.claimingErr is nil

// Then: a prebuild should have been claimed.
require.EqualValues(t, spy.claims.Load(), 1)
Expand Down Expand Up @@ -315,181 +341,6 @@ func TestClaimPrebuild(t *testing.T) {
}
}

// TestClaimPrebuild_CheckDifferentErrors exercises workspace creation while
// claiming a prebuilt workspace fails. Each case injects a claiming error via
// errorStore, reconciles every preset, waits until the expected number of
// prebuilds is running, and then runs its case-specific check.
func TestClaimPrebuild_CheckDifferentErrors(t *testing.T) {
	t.Parallel()

	if !dbtestutil.WillUsePostgres() {
		t.Skip("This test requires postgres")
	}

	const (
		desiredInstances = 1
		presetCount      = 2

		expectedPrebuildsCount = desiredInstances * presetCount
	)

	// checkFunc is the signature of the per-case assertion executed once the
	// prebuilds are up and running.
	type checkFunc func(
		t *testing.T,
		ctx context.Context,
		store database.Store,
		userClient *codersdk.Client,
		user codersdk.User,
		templateVersionID uuid.UUID,
		presetID uuid.UUID,
	)

	// newWorkspaceRequest builds a create-workspace request with a random,
	// dash-separated name for the given template version and preset.
	newWorkspaceRequest := func(t *testing.T, templateVersionID, presetID uuid.UUID) codersdk.CreateWorkspaceRequest {
		return codersdk.CreateWorkspaceRequest{
			TemplateVersionID:       templateVersionID,
			Name:                    strings.ReplaceAll(testutil.GetRandomName(t), "_", "-"),
			TemplateVersionPresetID: presetID,
		}
	}

	// requirePrebuildsUnchanged asserts that the store still reports the
	// expected number of running prebuilds.
	requirePrebuildsUnchanged := func(t *testing.T, ctx context.Context, store database.Store) {
		currentPrebuilds, err := store.GetRunningPrebuiltWorkspaces(ctx)
		require.NoError(t, err)
		require.Equal(t, expectedPrebuildsCount, len(currentPrebuilds))
	}

	cases := map[string]struct {
		claimingErr error
		checkFn     checkFunc
	}{
		"ErrNoClaimablePrebuiltWorkspaces is returned": {
			claimingErr: agplprebuilds.ErrNoClaimablePrebuiltWorkspaces,
			checkFn: func(
				t *testing.T,
				ctx context.Context,
				store database.Store,
				userClient *codersdk.Client,
				user codersdk.User,
				templateVersionID uuid.UUID,
				presetID uuid.UUID,
			) {
				// When: a user creates a new workspace with a preset for which prebuilds are configured.
				req := newWorkspaceRequest(t, templateVersionID, presetID)
				userWorkspace, err := userClient.CreateUserWorkspace(ctx, user.Username, req)

				// Then: creation succeeds and the workspace build completes.
				require.NoError(t, err)
				coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, userWorkspace.LatestBuild.ID)

				// Then: the number of running prebuilds is unchanged, i.e. no prebuild was claimed.
				requirePrebuildsUnchanged(t, ctx, store)
			},
		},
		"unexpected error during claim is returned": {
			claimingErr: xerrors.New("unexpected error during claim"),
			checkFn: func(
				t *testing.T,
				ctx context.Context,
				store database.Store,
				userClient *codersdk.Client,
				user codersdk.User,
				templateVersionID uuid.UUID,
				presetID uuid.UUID,
			) {
				// When: a user creates a new workspace with a preset for which prebuilds are configured.
				req := newWorkspaceRequest(t, templateVersionID, presetID)
				_, err := userClient.CreateUserWorkspace(ctx, user.Username, req)

				// Then: the unexpected error is propagated all the way to the caller.
				require.Error(t, err)
				require.ErrorContains(t, err, "unexpected error during claim")

				// Then: the number of running prebuilds is unchanged, i.e. no prebuild was claimed.
				requirePrebuildsUnchanged(t, ctx, store)
			},
		},
	}

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			// Setup.
			ctx := testutil.Context(t, testutil.WaitSuperLong)
			db, pubsub := dbtestutil.NewDB(t)
			failingStore := newErrorStore(db, tc.claimingErr)

			logger := testutil.Logger(t)
			client, _, api, owner := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
				Options: &coderdtest.Options{
					IncludeProvisionerDaemon: true,
					Database:                 failingStore,
					Pubsub:                   pubsub,
				},

				EntitlementsUpdateInterval: time.Second,
			})

			reconciler := prebuilds.NewStoreReconciler(failingStore, pubsub, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), api.PrometheusRegistry)
			var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(failingStore)
			api.AGPL.PrebuildsClaimer.Store(&claimer)

			version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, templateWithAgentAndPresetsWithPrebuilds(desiredInstances))
			_ = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
			coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
			presets, err := client.TemplateVersionPresets(ctx, version.ID)
			require.NoError(t, err)
			require.Len(t, presets, presetCount)

			userClient, user := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID, rbac.RoleMember())

			// Given: a snapshot of the reconciliation state.
			state, err := reconciler.SnapshotState(ctx, failingStore)
			require.NoError(t, err)
			require.Len(t, state.Presets, presetCount)

			// When: every preset is reconciled.
			for _, preset := range presets {
				presetState, err := state.FilterByPreset(preset.ID)
				require.NoError(t, err)
				require.NotNil(t, presetState)

				actions, err := reconciler.CalculateActions(ctx, *presetState)
				require.NoError(t, err)
				require.NotNil(t, actions)

				require.NoError(t, reconciler.ReconcilePreset(ctx, *presetState))
			}

			// Given: a set of running, eligible prebuilds eventually starts up.
			// Rows are keyed by preset ID, so the map is complete once every
			// preset has contributed a running prebuild.
			runningPrebuilds := make(map[uuid.UUID]database.GetRunningPrebuiltWorkspacesRow, expectedPrebuildsCount)
			require.Eventually(t, func() bool {
				rows, err := failingStore.GetRunningPrebuiltWorkspaces(ctx)
				if err != nil {
					return false
				}

				for _, row := range rows {
					runningPrebuilds[row.CurrentPresetID.UUID] = row

					agents, err := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, row.ID)
					if err != nil {
						return false
					}

					// Workspaces are eligible once its agent is marked "ready".
					for _, agent := range agents {
						params := database.UpdateWorkspaceAgentLifecycleStateByIDParams{
							ID:             agent.ID,
							LifecycleState: database.WorkspaceAgentLifecycleStateReady,
							StartedAt:      sql.NullTime{Time: time.Now().Add(time.Hour), Valid: true},
							ReadyAt:        sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true},
						}
						if err := db.UpdateWorkspaceAgentLifecycleStateByID(ctx, params); err != nil {
							return false
						}
					}
				}

				t.Logf("found %d running prebuilds so far, want %d", len(runningPrebuilds), expectedPrebuildsCount)

				return len(runningPrebuilds) == expectedPrebuildsCount
			}, testutil.WaitSuperLong, testutil.IntervalSlow)

			tc.checkFn(t, ctx, failingStore, userClient, user, version.ID, presets[0].ID)
		})
	}
}

func templateWithAgentAndPresetsWithPrebuilds(desiredInstances int32) *echo.Responses {
return &echo.Responses{
Parse: echo.ParseComplete,
Expand Down
Loading