diff --git a/coderd/database/db.go b/coderd/database/db.go index ae2c31a566cb3..0f923a861efb4 100644 --- a/coderd/database/db.go +++ b/coderd/database/db.go @@ -28,6 +28,7 @@ type Store interface { wrapper Ping(ctx context.Context) (time.Duration, error) + PGLocks(ctx context.Context) (PGLocks, error) InTx(func(Store) error, *TxOptions) error } @@ -48,13 +49,26 @@ type DBTX interface { GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error } +func WithSerialRetryCount(count int) func(*sqlQuerier) { + return func(q *sqlQuerier) { + q.serialRetryCount = count + } +} + // New creates a new database store using a SQL database connection. -func New(sdb *sql.DB) Store { +func New(sdb *sql.DB, opts ...func(*sqlQuerier)) Store { dbx := sqlx.NewDb(sdb, "postgres") - return &sqlQuerier{ + q := &sqlQuerier{ db: dbx, sdb: dbx, + // This is an arbitrary number. + serialRetryCount: 3, + } + + for _, opt := range opts { + opt(q) } + return q } // TxOptions is used to pass some execution metadata to the callers. @@ -104,6 +118,10 @@ type querier interface { type sqlQuerier struct { sdb *sqlx.DB db DBTX + + // serialRetryCount is the number of times to retry a transaction + // if it fails with a serialization error. + serialRetryCount int } func (*sqlQuerier) Wrappers() []string { @@ -143,11 +161,9 @@ func (q *sqlQuerier) InTx(function func(Store) error, txOpts *TxOptions) error { // If we are in a transaction already, the parent InTx call will handle the retry. // We do not want to duplicate those retries. if !inTx && sqlOpts.Isolation == sql.LevelSerializable { - // This is an arbitrarily chosen number. - const retryAmount = 3 var err error attempts := 0 - for attempts = 0; attempts < retryAmount; attempts++ { + for attempts = 0; attempts < q.serialRetryCount; attempts++ { txOpts.executionCount++ err = q.runTx(function, sqlOpts) if err == nil { @@ -203,3 +219,10 @@ func (q *sqlQuerier) runTx(function func(Store) error, txOpts *sql.TxOptions) er } return nil } + +func safeString(s *string) string { + if s == nil { + return "" + } + return *s +} diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index ae6b307b3e7d3..9bf98aade03c4 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -603,6 +603,10 @@ func (q *querier) Ping(ctx context.Context) (time.Duration, error) { return q.db.Ping(ctx) } +func (q *querier) PGLocks(ctx context.Context) (database.PGLocks, error) { + return q.db.PGLocks(ctx) +} + // InTx runs the given function in a transaction. func (q *querier) InTx(function func(querier database.Store) error, txOpts *database.TxOptions) error { return q.db.InTx(func(tx database.Store) error { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 439cf1bdaec19..ae50309e96d66 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -152,7 +152,10 @@ func TestDBAuthzRecursive(t *testing.T) { for i := 2; i < method.Type.NumIn(); i++ { ins = append(ins, reflect.New(method.Type.In(i)).Elem()) } - if method.Name == "InTx" || method.Name == "Ping" || method.Name == "Wrappers" { + if method.Name == "InTx" || + method.Name == "Ping" || + method.Name == "Wrappers" || + method.Name == "PGLocks" { continue } // Log the name of the last method, so if there is a panic, it is diff --git a/coderd/database/dbauthz/setup_test.go b/coderd/database/dbauthz/setup_test.go index df9d551101a25..52e8dd42fea9c 100644 --- a/coderd/database/dbauthz/setup_test.go +++ b/coderd/database/dbauthz/setup_test.go @@ -34,6 +34,7 @@ var errMatchAny = xerrors.New("match any error") var skipMethods = map[string]string{ "InTx": "Not relevant", "Ping": "Not relevant", + "PGLocks": "Not relevant", "Wrappers": "Not relevant", "AcquireLock": "Not relevant", "TryAcquireLock": "Not relevant", diff --git a/coderd/database/dbfake/builder.go b/coderd/database/dbfake/builder.go new file mode 100644 index 0000000000000..6803374e72445 --- /dev/null +++ b/coderd/database/dbfake/builder.go @@ -0,0 +1,127 @@ +package dbfake + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/testutil" +) + +type OrganizationBuilder struct { + t *testing.T + db database.Store + seed database.Organization + allUsersAllowance int32 + members []uuid.UUID + groups map[database.Group][]uuid.UUID +} + +func Organization(t *testing.T, db database.Store) OrganizationBuilder { + return OrganizationBuilder{ + t: t, + db: db, + members: []uuid.UUID{}, + groups: make(map[database.Group][]uuid.UUID), + } +} + +type OrganizationResponse struct { + Org database.Organization + AllUsersGroup database.Group + Members []database.OrganizationMember + Groups []database.Group +} + +func (b OrganizationBuilder) EveryoneAllowance(allowance int) OrganizationBuilder { + //nolint: revive // returns modified struct + b.allUsersAllowance = int32(allowance) + return b +} + +func (b OrganizationBuilder) Seed(seed database.Organization) OrganizationBuilder { + //nolint: revive // returns modified struct + b.seed = seed + return b +} + +func (b OrganizationBuilder) Members(users ...database.User) OrganizationBuilder { + for _, u := range users { + //nolint: revive // returns modified struct + b.members = append(b.members, u.ID) + } + return b +} + +func (b OrganizationBuilder) Group(seed database.Group, members ...database.User) OrganizationBuilder { + //nolint: revive // returns modified struct + b.groups[seed] = []uuid.UUID{} + for _, u := range members { + //nolint: revive // returns modified struct + b.groups[seed] = append(b.groups[seed], u.ID) + } + return b +} + +func (b OrganizationBuilder) Do() OrganizationResponse { + org := dbgen.Organization(b.t, b.db, b.seed) + + ctx := testutil.Context(b.t, testutil.WaitShort) + //nolint:gocritic // builder code needs perms + ctx = dbauthz.AsSystemRestricted(ctx) + everyone, err := b.db.InsertAllUsersGroup(ctx, org.ID) + require.NoError(b.t, err) + + if b.allUsersAllowance > 0 { + everyone, err = b.db.UpdateGroupByID(ctx, database.UpdateGroupByIDParams{ + Name: everyone.Name, + DisplayName: everyone.DisplayName, + AvatarURL: everyone.AvatarURL, + QuotaAllowance: b.allUsersAllowance, + ID: everyone.ID, + }) + require.NoError(b.t, err) + } + + members := make([]database.OrganizationMember, 0) + if len(b.members) > 0 { + for _, u := range b.members { + newMem := dbgen.OrganizationMember(b.t, b.db, database.OrganizationMember{ + UserID: u, + OrganizationID: org.ID, + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + Roles: nil, + }) + members = append(members, newMem) + } + } + + groups := make([]database.Group, 0) + if len(b.groups) > 0 { + for g, users := range b.groups { + g.OrganizationID = org.ID + group := dbgen.Group(b.t, b.db, g) + groups = append(groups, group) + + for _, u := range users { + dbgen.GroupMember(b.t, b.db, database.GroupMemberTable{ + UserID: u, + GroupID: group.ID, + }) + } + } + } + + return OrganizationResponse{ + Org: org, + AllUsersGroup: everyone, + Members: members, + Groups: groups, + } +} diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 69419b98c79b1..d369d8a023ba9 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -407,6 +407,8 @@ func OrganizationMember(t testing.TB, db database.Store, orig database.Organizat } func Group(t testing.TB, db database.Store, orig database.Group) database.Group { + t.Helper() + name := takeFirst(orig.Name, testutil.GetRandomName(t)) group, err := db.InsertGroup(genCtx, database.InsertGroupParams{ ID: takeFirst(orig.ID, uuid.New()), diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 4f54598744dd0..e38c3e107013f 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -339,6 +339,10 @@ func (*FakeQuerier) Ping(_ context.Context) (time.Duration, error) { return 0, nil } +func (*FakeQuerier) PGLocks(_ context.Context) (database.PGLocks, error) { + return []database.PGLock{}, nil +} + func (tx *fakeTx) AcquireLock(_ context.Context, id int64) error { if _, ok := tx.FakeQuerier.locks[id]; ok { return xerrors.Errorf("cannot acquire lock %d: already held", id) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 7e74aab3b9de0..e1cfec5bac9ca 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -66,6 +66,13 @@ func (m queryMetricsStore) Ping(ctx context.Context) (time.Duration, error) { return duration, err } +func (m queryMetricsStore) PGLocks(ctx context.Context) (database.PGLocks, error) { + start := time.Now() + locks, err := m.s.PGLocks(ctx) + m.queryLatencies.WithLabelValues("PGLocks").Observe(time.Since(start).Seconds()) + return locks, err +} + func (m queryMetricsStore) InTx(f func(database.Store) error, options *database.TxOptions) error { return m.dbMetrics.InTx(f, options) } diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index ffc9ab79f777e..27b398a062051 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -4299,6 +4299,21 @@ func (mr *MockStoreMockRecorder) OrganizationMembers(arg0, arg1 any) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OrganizationMembers", reflect.TypeOf((*MockStore)(nil).OrganizationMembers), arg0, arg1) } +// PGLocks mocks base method. +func (m *MockStore) PGLocks(arg0 context.Context) (database.PGLocks, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PGLocks", arg0) + ret0, _ := ret[0].(database.PGLocks) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PGLocks indicates an expected call of PGLocks. +func (mr *MockStoreMockRecorder) PGLocks(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PGLocks", reflect.TypeOf((*MockStore)(nil).PGLocks), arg0) +} + // Ping mocks base method. func (m *MockStore) Ping(arg0 context.Context) (time.Duration, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dbtestutil/db.go b/coderd/database/dbtestutil/db.go index 327d880f69648..bc8c571795629 100644 --- a/coderd/database/dbtestutil/db.go +++ b/coderd/database/dbtestutil/db.go @@ -135,7 +135,8 @@ func NewDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub) { if o.dumpOnFailure { t.Cleanup(func() { DumpOnFailure(t, connectionURL) }) } - db = database.New(sqlDB) + // Unit tests should not retry serial transaction failures. + db = database.New(sqlDB, database.WithSerialRetryCount(1)) ps, err = pubsub.New(context.Background(), o.logger, sqlDB, connectionURL) require.NoError(t, err) diff --git a/coderd/database/dbtestutil/tx.go b/coderd/database/dbtestutil/tx.go new file mode 100644 index 0000000000000..15be63dc35aeb --- /dev/null +++ b/coderd/database/dbtestutil/tx.go @@ -0,0 +1,73 @@ +package dbtestutil + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" +) + +type DBTx struct { + database.Store + mu sync.Mutex + done chan error + finalErr chan error +} + +// StartTx starts a transaction and returns a DBTx object. This allows running +// 2 transactions concurrently in a test more easily. +// Example: +// +// a := StartTx(t, db, opts) +// b := StartTx(t, db, opts) +// +// a.GetUsers(...) +// b.GetUsers(...) +// +// require.NoError(t, a.Done() +func StartTx(t *testing.T, db database.Store, opts *database.TxOptions) *DBTx { + done := make(chan error) + finalErr := make(chan error) + txC := make(chan database.Store) + + go func() { + t.Helper() + once := sync.Once{} + count := 0 + + err := db.InTx(func(store database.Store) error { + // InTx can be retried + once.Do(func() { + txC <- store + }) + count++ + if count > 1 { + // If you recursively call InTx, then don't use this. + t.Logf("InTx called more than once: %d", count) + assert.NoError(t, xerrors.New("InTx called more than once, this is not allowed with the StartTx helper")) + } + + <-done + // Just return nil. The caller should be checking their own errors. + return nil + }, opts) + finalErr <- err + }() + + txStore := <-txC + close(txC) + + return &DBTx{Store: txStore, done: done, finalErr: finalErr} +} + +// Done can only be called once. If you call it twice, it will panic. +func (tx *DBTx) Done() error { + tx.mu.Lock() + defer tx.mu.Unlock() + + close(tx.done) + return <-tx.finalErr +} diff --git a/coderd/database/pglocks.go b/coderd/database/pglocks.go new file mode 100644 index 0000000000000..85e1644b3825c --- /dev/null +++ b/coderd/database/pglocks.go @@ -0,0 +1,119 @@ +package database + +import ( + "context" + "fmt" + "reflect" + "sort" + "strings" + "time" + + "github.com/jmoiron/sqlx" + + "github.com/coder/coder/v2/coderd/util/slice" +) + +// PGLock docs see: https://www.postgresql.org/docs/current/view-pg-locks.html#VIEW-PG-LOCKS +type PGLock struct { + // LockType see: https://www.postgresql.org/docs/current/monitoring-stats.html#WAIT-EVENT-LOCK-TABLE + LockType *string `db:"locktype"` + Database *string `db:"database"` // oid + Relation *string `db:"relation"` // oid + RelationName *string `db:"relation_name"` + Page *int `db:"page"` + Tuple *int `db:"tuple"` + VirtualXID *string `db:"virtualxid"` + TransactionID *string `db:"transactionid"` // xid + ClassID *string `db:"classid"` // oid + ObjID *string `db:"objid"` // oid + ObjSubID *int `db:"objsubid"` + VirtualTransaction *string `db:"virtualtransaction"` + PID int `db:"pid"` + Mode *string `db:"mode"` + Granted bool `db:"granted"` + FastPath *bool `db:"fastpath"` + WaitStart *time.Time `db:"waitstart"` +} + +func (l PGLock) Equal(b PGLock) bool { + // Lazy, but hope this works + return reflect.DeepEqual(l, b) +} + +func (l PGLock) String() string { + granted := "granted" + if !l.Granted { + granted = "waiting" + } + var details string + switch safeString(l.LockType) { + case "relation": + details = "" + case "page": + details = fmt.Sprintf("page=%d", *l.Page) + case "tuple": + details = fmt.Sprintf("page=%d tuple=%d", *l.Page, *l.Tuple) + case "virtualxid": + details = "waiting to acquire virtual tx id lock" + default: + details = "???" + } + return fmt.Sprintf("%d-%5s [%s] %s/%s/%s: %s", + l.PID, + safeString(l.TransactionID), + granted, + safeString(l.RelationName), + safeString(l.LockType), + safeString(l.Mode), + details, + ) +} + +// PGLocks returns a list of all locks in the database currently in use. +func (q *sqlQuerier) PGLocks(ctx context.Context) (PGLocks, error) { + rows, err := q.sdb.QueryContext(ctx, ` + SELECT + relation::regclass AS relation_name, + * + FROM pg_locks; + `) + if err != nil { + return nil, err + } + + defer rows.Close() + + var locks []PGLock + err = sqlx.StructScan(rows, &locks) + if err != nil { + return nil, err + } + + return locks, err +} + +type PGLocks []PGLock + +func (l PGLocks) String() string { + // Try to group things together by relation name. + sort.Slice(l, func(i, j int) bool { + return safeString(l[i].RelationName) < safeString(l[j].RelationName) + }) + + var out strings.Builder + for i, lock := range l { + if i != 0 { + _, _ = out.WriteString("\n") + } + _, _ = out.WriteString(lock.String()) + } + return out.String() +} + +// Difference returns the difference between two sets of locks. +// This is helpful to determine what changed between the two sets. +func (l PGLocks) Difference(to PGLocks) (new PGLocks, removed PGLocks) { + return slice.SymmetricDifferenceFunc(l, to, func(a, b PGLock) bool { + return a.Equal(b) + }) +} diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 45cbef3f5e1d8..db0972debdb85 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -6736,23 +6736,33 @@ const getQuotaConsumedForUser = `-- name: GetQuotaConsumedForUser :one WITH latest_builds AS ( SELECT DISTINCT ON - (workspace_id) id, - workspace_id, - daily_cost + (wb.workspace_id) wb.workspace_id, + wb.daily_cost FROM workspace_builds wb + -- This INNER JOIN prevents a seq scan of the workspace_builds table. + -- Limit the rows to the absolute minimum required, which is all workspaces + -- in a given organization for a given user. +INNER JOIN + workspaces on wb.workspace_id = workspaces.id +WHERE + workspaces.owner_id = $1 AND + workspaces.organization_id = $2 ORDER BY - workspace_id, - created_at DESC + wb.workspace_id, + wb.created_at DESC ) SELECT coalesce(SUM(daily_cost), 0)::BIGINT FROM workspaces -JOIN latest_builds ON +INNER JOIN latest_builds ON latest_builds.workspace_id = workspaces.id -WHERE NOT - deleted AND +WHERE + NOT deleted AND + -- We can likely remove these conditions since we check above. + -- But it does not hurt to be defensive and make sure future query changes + -- do not break anything. workspaces.owner_id = $1 AND workspaces.organization_id = $2 ` diff --git a/coderd/database/queries/quotas.sql b/coderd/database/queries/quotas.sql index 48f9209783e4e..7ab6189dfe8a1 100644 --- a/coderd/database/queries/quotas.sql +++ b/coderd/database/queries/quotas.sql @@ -18,23 +18,33 @@ INNER JOIN groups ON WITH latest_builds AS ( SELECT DISTINCT ON - (workspace_id) id, - workspace_id, - daily_cost + (wb.workspace_id) wb.workspace_id, + wb.daily_cost FROM workspace_builds wb + -- This INNER JOIN prevents a seq scan of the workspace_builds table. + -- Limit the rows to the absolute minimum required, which is all workspaces + -- in a given organization for a given user. +INNER JOIN + workspaces on wb.workspace_id = workspaces.id +WHERE + workspaces.owner_id = @owner_id AND + workspaces.organization_id = @organization_id ORDER BY - workspace_id, - created_at DESC + wb.workspace_id, + wb.created_at DESC ) SELECT coalesce(SUM(daily_cost), 0)::BIGINT FROM workspaces -JOIN latest_builds ON +INNER JOIN latest_builds ON latest_builds.workspace_id = workspaces.id -WHERE NOT - deleted AND +WHERE + NOT deleted AND + -- We can likely remove these conditions since we check above. + -- But it does not hurt to be defensive and make sure future query changes + -- do not break anything. workspaces.owner_id = @owner_id AND workspaces.organization_id = @organization_id ; diff --git a/enterprise/coderd/workspacequota_test.go b/enterprise/coderd/workspacequota_test.go index ac4a77eaec8b4..13142f11e5717 100644 --- a/enterprise/coderd/workspacequota_test.go +++ b/enterprise/coderd/workspacequota_test.go @@ -2,11 +2,13 @@ package coderd_test import ( "context" + "database/sql" "encoding/json" "fmt" "net/http" "sync" "testing" + "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -14,6 +16,11 @@ import ( "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbfake" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" @@ -295,6 +302,497 @@ func TestWorkspaceQuota(t *testing.T) { }) } +// nolint:paralleltest,tparallel // Tests must run serially +func TestWorkspaceSerialization(t *testing.T) { + t.Parallel() + + if !dbtestutil.WillUsePostgres() { + t.Skip("Serialization errors only occur in postgres") + } + + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + otherUser := dbgen.User(t, db, database.User{}) + + org := dbfake.Organization(t, db). + EveryoneAllowance(20). + Members(user, otherUser). + Group(database.Group{ + QuotaAllowance: 10, + }, user, otherUser). + Group(database.Group{ + QuotaAllowance: 10, + }, user). + Do() + + otherOrg := dbfake.Organization(t, db). + EveryoneAllowance(20). + Members(user, otherUser). + Group(database.Group{ + QuotaAllowance: 10, + }, user, otherUser). + Group(database.Group{ + QuotaAllowance: 10, + }, user). + Do() + + // TX mixing tests. **DO NOT** run these in parallel. + // The goal here is to mess around with different ordering of + // transactions and queries. + + // UpdateBuildDeadline bumps a workspace deadline while doing a quota + // commit to the same workspace build. + // + // Note: This passes if the interrupt is run before 'GetQuota()' + // Passing orders: + // - BeginTX -> Bump! -> GetQuota -> GetAllowance -> UpdateCost -> EndTx + // - BeginTX -> GetQuota -> GetAllowance -> UpdateCost -> Bump! -> EndTx + t.Run("UpdateBuildDeadline", func(t *testing.T) { + t.Log("Expected to fail. As long as quota & deadline are on the same " + + " table and affect the same row, this will likely always fail.") + + // +------------------------------+------------------+ + // | Begin Tx | | + // +------------------------------+------------------+ + // | GetQuota(user) | | + // +------------------------------+------------------+ + // | | BumpDeadline(w1) | + // +------------------------------+------------------+ + // | GetAllowance(user) | | + // +------------------------------+------------------+ + // | UpdateWorkspaceBuildCost(w1) | | + // +------------------------------+------------------+ + // | CommitTx() | | + // +------------------------------+------------------+ + // pq: could not serialize access due to concurrent update + ctx := testutil.Context(t, testutil.WaitLong) + //nolint:gocritic // testing + ctx = dbauthz.AsSystemRestricted(ctx) + + myWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: org.Org.ID, + OwnerID: user.ID, + }).Do() + + bumpDeadline := func() { + err := db.InTx(func(db database.Store) error { + err := db.UpdateWorkspaceBuildDeadlineByID(ctx, database.UpdateWorkspaceBuildDeadlineByIDParams{ + Deadline: dbtime.Now(), + MaxDeadline: dbtime.Now(), + UpdatedAt: dbtime.Now(), + ID: myWorkspace.Build.ID, + }) + return err + }, &database.TxOptions{ + Isolation: sql.LevelSerializable, + }) + assert.NoError(t, err) + } + + // Start TX + // Run order + + quota := newCommitter(t, db, myWorkspace.Workspace, myWorkspace.Build) + quota.GetQuota(ctx, t) // Step 1 + bumpDeadline() // Interrupt + quota.GetAllowance(ctx, t) // Step 2 + + err := quota.DBTx.UpdateWorkspaceBuildCostByID(ctx, database.UpdateWorkspaceBuildCostByIDParams{ + ID: myWorkspace.Build.ID, + DailyCost: 10, + }) // Step 3 + require.ErrorContains(t, err, "could not serialize access due to concurrent update") + // End commit + require.ErrorContains(t, quota.Done(), "failed transaction") + }) + + // UpdateOtherBuildDeadline bumps a user's other workspace deadline + // while doing a quota commit. + t.Run("UpdateOtherBuildDeadline", func(t *testing.T) { + // +------------------------------+------------------+ + // | Begin Tx | | + // +------------------------------+------------------+ + // | GetQuota(user) | | + // +------------------------------+------------------+ + // | | BumpDeadline(w2) | + // +------------------------------+------------------+ + // | GetAllowance(user) | | + // +------------------------------+------------------+ + // | UpdateWorkspaceBuildCost(w1) | | + // +------------------------------+------------------+ + // | CommitTx() | | + // +------------------------------+------------------+ + // Works! + ctx := testutil.Context(t, testutil.WaitLong) + //nolint:gocritic // testing + ctx = dbauthz.AsSystemRestricted(ctx) + + myWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: org.Org.ID, + OwnerID: user.ID, + }).Do() + + // Use the same template + otherWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: org.Org.ID, + OwnerID: user.ID, + }). + Seed(database.WorkspaceBuild{ + TemplateVersionID: myWorkspace.TemplateVersion.ID, + }). + Do() + + bumpDeadline := func() { + err := db.InTx(func(db database.Store) error { + err := db.UpdateWorkspaceBuildDeadlineByID(ctx, database.UpdateWorkspaceBuildDeadlineByIDParams{ + Deadline: dbtime.Now(), + MaxDeadline: dbtime.Now(), + UpdatedAt: dbtime.Now(), + ID: otherWorkspace.Build.ID, + }) + return err + }, &database.TxOptions{ + Isolation: sql.LevelSerializable, + }) + assert.NoError(t, err) + } + + // Start TX + // Run order + + quota := newCommitter(t, db, myWorkspace.Workspace, myWorkspace.Build) + quota.GetQuota(ctx, t) // Step 1 + bumpDeadline() // Interrupt + quota.GetAllowance(ctx, t) // Step 2 + quota.UpdateWorkspaceBuildCostByID(ctx, t, 10) // Step 3 + // End commit + require.NoError(t, quota.Done()) + }) + + t.Run("ActivityBump", func(t *testing.T) { + t.Log("Expected to fail. As long as quota & deadline are on the same " + + " table and affect the same row, this will likely always fail.") + // +---------------------+----------------------------------+ + // | W1 Quota Tx | | + // +---------------------+----------------------------------+ + // | Begin Tx | | + // +---------------------+----------------------------------+ + // | GetQuota(w1) | | + // +---------------------+----------------------------------+ + // | GetAllowance(w1) | | + // +---------------------+----------------------------------+ + // | | ActivityBump(w1) | + // +---------------------+----------------------------------+ + // | UpdateBuildCost(w1) | | + // +---------------------+----------------------------------+ + // | CommitTx() | | + // +---------------------+----------------------------------+ + // pq: could not serialize access due to concurrent update + ctx := testutil.Context(t, testutil.WaitShort) + //nolint:gocritic // testing + ctx = dbauthz.AsSystemRestricted(ctx) + + myWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: org.Org.ID, + OwnerID: user.ID, + }). + Seed(database.WorkspaceBuild{ + // Make sure the bump does something + Deadline: dbtime.Now().Add(time.Hour * -20), + }). + Do() + + one := newCommitter(t, db, myWorkspace.Workspace, myWorkspace.Build) + + // Run order + one.GetQuota(ctx, t) + one.GetAllowance(ctx, t) + + err := db.ActivityBumpWorkspace(ctx, database.ActivityBumpWorkspaceParams{ + NextAutostart: time.Now(), + WorkspaceID: myWorkspace.Workspace.ID, + }) + + assert.NoError(t, err) + + err = one.DBTx.UpdateWorkspaceBuildCostByID(ctx, database.UpdateWorkspaceBuildCostByIDParams{ + ID: myWorkspace.Build.ID, + DailyCost: 10, + }) + require.ErrorContains(t, err, "could not serialize access due to concurrent update") + + // End commit + assert.ErrorContains(t, one.Done(), "failed transaction") + }) + + t.Run("BumpLastUsedAt", func(t *testing.T) { + // +---------------------+----------------------------------+ + // | W1 Quota Tx | | + // +---------------------+----------------------------------+ + // | Begin Tx | | + // +---------------------+----------------------------------+ + // | GetQuota(w1) | | + // +---------------------+----------------------------------+ + // | GetAllowance(w1) | | + // +---------------------+----------------------------------+ + // | | UpdateWorkspaceLastUsedAt(w1) | + // +---------------------+----------------------------------+ + // | UpdateBuildCost(w1) | | + // +---------------------+----------------------------------+ + // | CommitTx() | | + // +---------------------+----------------------------------+ + ctx := testutil.Context(t, testutil.WaitShort) + //nolint:gocritic // testing + ctx = dbauthz.AsSystemRestricted(ctx) + + myWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: org.Org.ID, + OwnerID: user.ID, + }).Do() + + one := newCommitter(t, db, myWorkspace.Workspace, myWorkspace.Build) + + // Run order + one.GetQuota(ctx, t) + one.GetAllowance(ctx, t) + + err := db.UpdateWorkspaceLastUsedAt(ctx, database.UpdateWorkspaceLastUsedAtParams{ + ID: myWorkspace.Workspace.ID, + LastUsedAt: dbtime.Now(), + }) + assert.NoError(t, err) + + one.UpdateWorkspaceBuildCostByID(ctx, t, 10) + + // End commit + assert.NoError(t, one.Done()) + }) + + t.Run("UserMod", func(t *testing.T) { + // +---------------------+----------------------------------+ + // | W1 Quota Tx | | + // +---------------------+----------------------------------+ + // | Begin Tx | | + // +---------------------+----------------------------------+ + // | GetQuota(w1) | | + // +---------------------+----------------------------------+ + // | GetAllowance(w1) | | + // +---------------------+----------------------------------+ + // | | RemoveUserFromOrg | + // +---------------------+----------------------------------+ + // | UpdateBuildCost(w1) | | + // +---------------------+----------------------------------+ + // | CommitTx() | | + // +---------------------+----------------------------------+ + // Works! + ctx := testutil.Context(t, testutil.WaitShort) + //nolint:gocritic // testing + ctx = dbauthz.AsSystemRestricted(ctx) + var err error + + myWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: org.Org.ID, + OwnerID: user.ID, + }).Do() + + one := newCommitter(t, db, myWorkspace.Workspace, myWorkspace.Build) + + // Run order + + one.GetQuota(ctx, t) + one.GetAllowance(ctx, t) + + err = db.DeleteOrganizationMember(ctx, database.DeleteOrganizationMemberParams{ + OrganizationID: myWorkspace.Workspace.OrganizationID, + UserID: user.ID, + }) + assert.NoError(t, err) + + one.UpdateWorkspaceBuildCostByID(ctx, t, 10) + + // End commit + assert.NoError(t, one.Done()) + }) + + // QuotaCommit 2 workspaces in different orgs. + // Workspaces do not share templates, owners, or orgs + t.Run("DoubleQuotaUnrelatedWorkspaces", func(t *testing.T) { + // +---------------------+---------------------+ + // | W1 Quota Tx | W2 Quota Tx | + // +---------------------+---------------------+ + // | Begin Tx | | + // +---------------------+---------------------+ + // | | Begin Tx | + // +---------------------+---------------------+ + // | GetQuota(w1) | | + // +---------------------+---------------------+ + // | GetAllowance(w1) | | + // +---------------------+---------------------+ + // | UpdateBuildCost(w1) | | + // +---------------------+---------------------+ + // | | UpdateBuildCost(w2) | + // +---------------------+---------------------+ + // | | GetQuota(w2) | + // +---------------------+---------------------+ + // | | GetAllowance(w2) | + // +---------------------+---------------------+ + // | CommitTx() | | + // +---------------------+---------------------+ + // | | CommitTx() | + // +---------------------+---------------------+ + ctx := testutil.Context(t, testutil.WaitLong) + //nolint:gocritic // testing + ctx = dbauthz.AsSystemRestricted(ctx) + + myWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: org.Org.ID, + OwnerID: user.ID, + }).Do() + + myOtherWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: otherOrg.Org.ID, // Different org! + OwnerID: otherUser.ID, + }).Do() + + one := newCommitter(t, db, myWorkspace.Workspace, myWorkspace.Build) + two := newCommitter(t, db, myOtherWorkspace.Workspace, myOtherWorkspace.Build) + + // Run order + one.GetQuota(ctx, t) + one.GetAllowance(ctx, t) + + one.UpdateWorkspaceBuildCostByID(ctx, t, 10) + + two.GetQuota(ctx, t) + two.GetAllowance(ctx, t) + two.UpdateWorkspaceBuildCostByID(ctx, t, 10) + + // End commit + assert.NoError(t, one.Done()) + assert.NoError(t, two.Done()) + }) + + // QuotaCommit 2 workspaces in different orgs. + // Workspaces do not share templates or orgs + t.Run("DoubleQuotaUserWorkspacesDiffOrgs", func(t *testing.T) { + // +---------------------+---------------------+ + // | W1 Quota Tx | W2 Quota Tx | + // +---------------------+---------------------+ + // | Begin Tx | | + // +---------------------+---------------------+ + // | | Begin Tx | + // +---------------------+---------------------+ + // | GetQuota(w1) | | + // +---------------------+---------------------+ + // | GetAllowance(w1) | | + // +---------------------+---------------------+ + // | UpdateBuildCost(w1) | | + // +---------------------+---------------------+ + // | | UpdateBuildCost(w2) | + // +---------------------+---------------------+ + // | | GetQuota(w2) | + // +---------------------+---------------------+ + // | | GetAllowance(w2) | + // +---------------------+---------------------+ + // | CommitTx() | | + // +---------------------+---------------------+ + // | | CommitTx() | + // +---------------------+---------------------+ + ctx := testutil.Context(t, testutil.WaitLong) + //nolint:gocritic // testing + ctx = dbauthz.AsSystemRestricted(ctx) + + myWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: org.Org.ID, + OwnerID: user.ID, + }).Do() + + myOtherWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: otherOrg.Org.ID, // Different org! + OwnerID: user.ID, + }).Do() + + one := newCommitter(t, db, myWorkspace.Workspace, myWorkspace.Build) + two := newCommitter(t, db, myOtherWorkspace.Workspace, myOtherWorkspace.Build) + + // Run order + one.GetQuota(ctx, t) + one.GetAllowance(ctx, t) + + one.UpdateWorkspaceBuildCostByID(ctx, t, 10) + + two.GetQuota(ctx, t) + two.GetAllowance(ctx, t) + two.UpdateWorkspaceBuildCostByID(ctx, t, 10) + + // End commit + assert.NoError(t, one.Done()) + assert.NoError(t, two.Done()) + }) + + // QuotaCommit 2 workspaces in the same org. + // Workspaces do not share templates + t.Run("DoubleQuotaUserWorkspaces", func(t *testing.T) { + t.Log("Setting a new build cost to a workspace in a org affects other " + + "workspaces in the same org. This is expected to fail.") + // +---------------------+---------------------+ + // | W1 Quota Tx | W2 Quota Tx | + // +---------------------+---------------------+ + // | Begin Tx | | + // +---------------------+---------------------+ + // | | Begin Tx | + // +---------------------+---------------------+ + // | GetQuota(w1) | | + // +---------------------+---------------------+ + // | GetAllowance(w1) | | + // +---------------------+---------------------+ + // | UpdateBuildCost(w1) | | + // +---------------------+---------------------+ + // | | UpdateBuildCost(w2) | + // +---------------------+---------------------+ + // | | GetQuota(w2) | + // +---------------------+---------------------+ + // | | GetAllowance(w2) | + // +---------------------+---------------------+ + // | CommitTx() | | + // +---------------------+---------------------+ + // | | CommitTx() | + // +---------------------+---------------------+ + // pq: could not serialize access due to read/write dependencies among transactions + ctx := testutil.Context(t, testutil.WaitLong) + //nolint:gocritic // testing + ctx = dbauthz.AsSystemRestricted(ctx) + + myWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: org.Org.ID, + OwnerID: user.ID, + }).Do() + + myOtherWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: org.Org.ID, + OwnerID: user.ID, + }).Do() + + one := newCommitter(t, db, myWorkspace.Workspace, myWorkspace.Build) + two := newCommitter(t, db, myOtherWorkspace.Workspace, myOtherWorkspace.Build) + + // Run order + one.GetQuota(ctx, t) + one.GetAllowance(ctx, t) + + one.UpdateWorkspaceBuildCostByID(ctx, t, 10) + + two.GetQuota(ctx, t) + two.GetAllowance(ctx, t) + two.UpdateWorkspaceBuildCostByID(ctx, t, 10) + + // End commit + assert.NoError(t, one.Done()) + assert.ErrorContains(t, two.Done(), "could not serialize access due to read/write dependencies among transactions") + }) +} + func deprecatedQuotaEndpoint(ctx context.Context, client *codersdk.Client, userID string) (codersdk.WorkspaceQuota, error) { res, err := client.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspace-quota/%s", userID), nil) if err != nil { @@ -335,3 +833,65 @@ func applyWithCost(cost int32) []*proto.Response { }, }} } + +// committer does what the CommitQuota does, but allows +// stepping through the actions in the tx and controlling the +// timing. +// This is a nice wrapper to make the tests more concise. +type committer struct { + DBTx *dbtestutil.DBTx + w database.WorkspaceTable + b database.WorkspaceBuild +} + +func newCommitter(t *testing.T, db database.Store, workspace database.WorkspaceTable, build database.WorkspaceBuild) *committer { + quotaTX := dbtestutil.StartTx(t, db, &database.TxOptions{ + Isolation: sql.LevelSerializable, + ReadOnly: false, + }) + return &committer{DBTx: quotaTX, w: workspace, b: build} +} + +// GetQuota touches: +// - workspace_builds +// - workspaces +func (c *committer) GetQuota(ctx context.Context, t *testing.T) int64 { + t.Helper() + + consumed, err := c.DBTx.GetQuotaConsumedForUser(ctx, database.GetQuotaConsumedForUserParams{ + OwnerID: c.w.OwnerID, + OrganizationID: c.w.OrganizationID, + }) + require.NoError(t, err) + return consumed +} + +// GetAllowance touches: +// - group_members_expanded +// - users +// - groups +// - org_members +func (c *committer) GetAllowance(ctx context.Context, t *testing.T) int64 { + t.Helper() + + allowance, err := c.DBTx.GetQuotaAllowanceForUser(ctx, database.GetQuotaAllowanceForUserParams{ + UserID: c.w.OwnerID, + OrganizationID: c.w.OrganizationID, + }) + require.NoError(t, err) + return allowance +} + +func (c *committer) UpdateWorkspaceBuildCostByID(ctx context.Context, t *testing.T, cost int32) bool { + t.Helper() + + err := c.DBTx.UpdateWorkspaceBuildCostByID(ctx, database.UpdateWorkspaceBuildCostByIDParams{ + ID: c.b.ID, + DailyCost: cost, + }) + return assert.NoError(t, err) +} + +func (c *committer) Done() error { + return c.DBTx.Done() +}